def make_image_grid()

in python/mxboard/utils.py [0:0]


def make_image_grid(tensor, nrow=8, padding=2, normalize=False, norm_range=None,
                    scale_each=False, pad_value=0, square_image=False):
    """Make a grid of images. This is an MXNet version of torchvision.utils.make_grid
    Ref: https://github.com/pytorch/vision/blob/master/torchvision/utils.py

    Parameters
    ----------
        tensor : `NDArray` or list of `NDArray`s
            Input image(s) in the format of HW, CHW, or NCHW. A list (or tuple) of
            equally-shaped HW or CHW `NDArray`s is stacked along a new leading batch
            axis first.
        nrow : int
            Maximum number of images displayed in each row of the grid (must be at
            least 1). With ``ncol = min(nrow, batch_size)``, the grid has
            ``ceil(batch_size / ncol)`` rows of ``ncol`` images when `square_image`
            is False; otherwise it has ``ncol`` rows of ``ncol`` images.
        padding : int
            Number of padding pixels placed between images and around the grid border.
        normalize : bool
            If True, shift the image to the range (0, 1) by subtracting the minimum
            and dividing by the difference between the maximum and the minimum.
        norm_range : tuple
            Tuple of (min, max) where min and max are numbers. Pixel values are clipped
            to this range and then normalized with it. By default, `min` and `max` are
            computed from the `tensor`.
        scale_each : bool
            If True, scale each image in the batch of images separately rather than the
            `(min, max)` over all images.
        pad_value : float
            Value for the padded pixels.
        square_image : bool
            If True, force the generated image grid to be strictly square. The last
            rows of the image grid may be entirely blank as a result, and images
            beyond the first ``ncol * ncol`` are not drawn.

    Returns
    -------
    NDArray
        An image grid made of the input images, in CHW layout. Single-channel
        inputs are replicated to three channels. If the input holds a single image,
        that image is returned (after optional normalization) without padding.

    Raises
    ------
    TypeError
        If `tensor` is neither an `NDArray` nor a list/tuple of `NDArray`s.
    ValueError
        If the input list is empty, the input has an unsupported number of
        dimensions, `norm_range` is malformed, or `nrow` is smaller than 1.
    """
    is_sequence = isinstance(tensor, (list, tuple))
    if not (isinstance(tensor, NDArray) or
            (is_sequence and all(isinstance(t, NDArray) for t in tensor))):
        raise TypeError('MXNet NDArray or list of NDArrays expected, got {}'.format(
            str(type(tensor))))

    # if list of tensors, convert to a 4D mini-batch Tensor
    if is_sequence:
        if len(tensor) == 0:
            raise ValueError('expected a non-empty list of NDArrays')
        # MXNet's stack() takes its inputs as positional arguments, so unpack the list
        tensor = op.stack(*tensor, axis=0)
        if tensor.ndim == 3:  # list of H x W images -> N x 1 x H x W mini-batch
            tensor = tensor.reshape((tensor.shape[0], 1) + tensor.shape[1:])

    if tensor.ndim <= 1 or tensor.ndim > 4:
        raise ValueError('expected 2D, 3D, or 4D NDArrays, while received ndim={}'.format(
            tensor.ndim))

    if tensor.ndim == 2:  # single image H x W -> 1 x H x W
        tensor = tensor.reshape((1,) + tensor.shape)
    if tensor.ndim == 3:  # single image C x H x W
        if tensor.shape[0] == 1:  # if single-channel, convert to 3-channel
            tensor = op.concat(tensor, tensor, tensor, dim=0)
        tensor = tensor.reshape((1,) + tensor.shape)
    # from here on the tensor is always N x C x H x W
    if tensor.shape[1] == 1:  # single-channel images -> 3-channel
        tensor = op.concat(tensor, tensor, tensor, dim=1)

    if normalize is True:
        if norm_range is not None and not (isinstance(norm_range, tuple)
                                           and len(norm_range) == 2):
            raise ValueError("norm_range has to be a tuple (min, max) if specified. "
                             "min and max are numbers")
        # reshape() may return a view of the caller's data; copy before in-place ops
        tensor = tensor.copy()

        def norm_ip(img, val_min, val_max):
            """Clip `img` to [val_min, val_max] and rescale it in place to [0, 1]."""
            op.clip(img, a_min=val_min, a_max=val_max, out=img)
            img -= val_min
            # a constant image has max == min; avoid dividing by zero (as torchvision does)
            img /= max(val_max - val_min, 1e-5)

        def norm_range_helper(t, val_range):
            """Normalize `t` in place with `val_range`, or with its own min/max."""
            if val_range is not None:
                norm_ip(t, val_range[0], val_range[1])
            else:
                norm_ip(t, t.min().asscalar(), t.max().asscalar())

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension; each t writes into tensor
                norm_range_helper(t, norm_range)
        else:
            norm_range_helper(tensor, norm_range)

    # if single image, just return
    if tensor.shape[0] == 1:
        return tensor.squeeze(axis=0)

    if nrow < 1:
        raise ValueError('nrow must be a positive integer, while received nrow={}'.format(
            nrow))

    # make the batch of images into a grid
    nmaps, channels, img_h, img_w = tensor.shape
    xmaps = min(nrow, nmaps)
    # integer ceil division avoids float rounding issues
    ymaps = xmaps if square_image else (nmaps + xmaps - 1) // xmaps
    height, width = int(img_h + padding), int(img_w + padding)
    grid = nd.empty(shape=(channels, height * ymaps + padding, width * xmaps + padding),
                    dtype=tensor.dtype, ctx=tensor.context)
    grid[:] = pad_value
    # in square mode the grid may have fewer cells than images; extra images are dropped
    for k in range(min(nmaps, xmaps * ymaps)):
        y, x = divmod(k, xmaps)  # row-major placement
        top = y * height + padding
        left = x * width + padding
        grid[:, top:top + img_h, left:left + img_w] = tensor[k]
    return grid