def geom_imshow()

in python-package/lets_plot/plot/geom_imshow_.py [0:0]


def geom_imshow(image_data, cmap=None, *,
                norm=None, alpha=None,
                vmin=None, vmax=None,
                extent=None,
                compression=None,
                show_legend=True,
                color_by="paint_c",
                cguide=None
                ):
    """
    Display an image specified by an ndarray with shape:

    - (M, N) - greyscale image
    - (M, N, 3) - color RGB image
    - (M, N, 4) - color RGBA image with an alpha channel

    This geom is not as flexible as `geom_raster() <https://lets-plot.org/python/pages/api/lets_plot.geom_raster.html>`__
    or `geom_tile() <https://lets-plot.org/python/pages/api/lets_plot.geom_tile.html>`__
    but vastly superior in terms of rendering efficiency.

    Parameters
    ----------
    image_data : ndarray
        Specify image type, size, and pixel values.
        Supported array shapes are:

        - (M, N): an image with scalar data. The values are mapped to colors (greys by default) using normalization. See parameters ``norm``, ``cmap``, ``vmin``, ``vmax``.
        - (M, N, 3): an image with RGB values (0-1 float or 0-255 int).
        - (M, N, 4): an image with RGBA values (0-1 float or 0-255 int).

        The first two dimensions (M, N) define the rows and columns of the image.
        Out-of-range values are clipped.
    cmap : str or list, optional
        Name of colormap or a list of colors.
        If a string, it should be the name of a colormap supported by the
        Palettable package (https://github.com/jiffyclub/palettable),
        for example, "viridis", "magma", "plasma", "inferno".
        If a list, it should contain color strings in hex ('#RRGGBB'),
        'rgb(r, g, b)', or 'rgba(r, g, b, a)' format with int components (0-255).
        The greyscale values will be quantized to map to the provided colors.
        This parameter is ignored for RGB(A) images.
    norm : bool, optional
        True (default) - luminance values in greyscale images will be scaled to [0-255] range using a linear scaler.
        False - disables scaling of luminance values in greyscale images.
        This parameter is ignored for RGB(A) images.
    alpha : float, optional
        The alpha blending value, between 0 (transparent) and 1 (opaque).
    vmin, vmax : number, optional
        Define the data range used for luminance normalization in greyscale images.
        This parameter is ignored for RGB(A) images or if parameter ``norm=False``.
    extent : list of 4 numbers: [left, right, bottom, top], optional
        Define the image's bounding box in terms of the "data coordinates".

        - ``left``, ``right``: coordinates of the pixels' outer edges along the x-axis for pixels in the 1st and the last column.
        - ``bottom``, ``top``: coordinates of the pixels' outer edges along the y-axis for pixels in the 1st and the last row.

        The default is: [-0.5, ncol-0.5, -0.5, nrow-0.5]
    compression : int, optional
        The compression level to be used by the ``zlib`` module.
        Values from 0 (no compression) to 9 (highest).
        Value None means that the ``zlib`` module uses
        the default level of compression (which is generally acceptable).
    show_legend : bool, default=True
        Greyscale images only.
        False - do not show the legend for this layer.
    color_by : {'fill', 'color', 'paint_a', 'paint_b', 'paint_c'}, default='paint_c'
        Define the color aesthetic used by the legend shown for a greyscale image.
    cguide : optional
        A result of `guide_colorbar() <https://lets-plot.org/python/pages/api/lets_plot.guide_colorbar.html>`__ call.
        Use to customize the colorbar for greyscale images.

    Returns
    -------
    ``LayerSpec``
        Geom object specification.

    Raises
    ------
    ValueError
        If ``pypng`` is not installed, ``image_data`` is not a 2d/3d ndarray,
        a 3d ``image_data`` has a number of channels other than 3 or 4,
        ``alpha`` is outside [0..1], ``compression`` is outside [0..9],
        ``cmap`` is neither a string nor a non-empty list
        (or is a string while Palettable is not installed),
        or ``extent`` is not a sequence of 4 numbers.

    Notes
    -----
    This geom doesn't understand any aesthetics.
    It doesn't support color scales either.

    Examples
    --------
    .. jupyter-execute::
        :linenos:
        :emphasize-lines: 6

        import numpy as np
        from lets_plot import *
        LetsPlot.setup_html()
        np.random.seed(42)
        image = np.random.randint(256, size=(64, 64, 4))
        ggplot() + geom_imshow(image)

    |

    .. jupyter-execute::
        :linenos:
        :emphasize-lines: 7

        import numpy as np
        from lets_plot import *
        LetsPlot.setup_html()
        n = 64
        image = 256 * np.linspace(np.linspace(0, .5, n), \\
                                  np.linspace(.5, .5, n), n)
        ggplot() + geom_imshow(image, norm=False)

    |

    .. jupyter-execute::
        :linenos:
        :emphasize-lines: 6

        import numpy as np
        from lets_plot import *
        LetsPlot.setup_html()
        np.random.seed(42)
        image = np.random.normal(size=(64, 64))
        ggplot() + geom_imshow(image, vmin=-1, vmax=1)

    """

    # ---- Argument validation ----
    if png is None:
        raise ValueError("pypng is not installed")

    if not is_ndarray(image_data):
        raise ValueError("Invalid image_data: ndarray is expected but was {}".format(type(image_data)))

    if image_data.ndim not in (2, 3):
        raise ValueError(
            "Invalid image_data: 2d or 3d array is expected but was {}-dimensional".format(image_data.ndim))

    if alpha is not None:
        if not (0 <= alpha <= 1):
            raise ValueError(
                "Invalid alpha: expected float in range [0..1] but was {}".format(alpha))

    if compression is not None:
        if not (0 <= compression <= 9):
            raise ValueError(
                "Invalid compression: expected integer in range [0..9] but was {}".format(compression))

    # ---- Convert the input to an (M, N, 4) RGBA array ----
    greyscale = (image_data.ndim == 2)
    if greyscale:
        # Greyscale image

        has_nan = numpy.isnan(image_data.max())
        min_lum = 0 if not (has_nan and cmap) else 1  # index 0 reserved for NaN-s

        (image_data, greyscale_data_min, greyscale_data_max) = _normalize_2D(image_data, norm, vmin, vmax, min_lum)
        height, width = image_data.shape
        # Re-check on the normalized data: it is what gets mapped to colors below.
        has_nan = numpy.isnan(image_data.max())

        cmap_list = isinstance(cmap, list)
        cmap_palettable = isinstance(cmap, str)
        if cmap and not cmap_list and not cmap_palettable:
            raise ValueError(
                "cmap must be a string (colormap name) or a list of colors, but was {}".format(type(cmap).__name__))

        if cmap:
            alpha_ch_val = 255 if alpha is None else 255 * alpha

            if cmap_palettable:
                # colormap via palettable
                if not palettable:
                    raise ValueError(
                        "Can't process `cmap`: please install 'Palettable' (https://pypi.org/project/palettable/) to your "
                        "Python environment. "
                    )

                # prepare palette
                if not has_nan:
                    cmap_256 = palettable.get_map(cmap + "_256")
                    palette = [_parse_hex_color(c, alpha_ch_val) for c in cmap_256.hex_colors]
                else:
                    cmap_255 = palettable.get_map(cmap + "_255")
                    # transparent color at index 0
                    palette = [numpy.array([0, 0, 0, 0], dtype=numpy.uint8)] \
                              + [_parse_hex_color(c, alpha_ch_val) for c in cmap_255.hex_colors]
            else:
                # custom color list - build expanded palette with quantization
                n_colors = len(cmap)
                if n_colors == 0:
                    raise ValueError("cmap list must contain at least one color")

                if not has_nan:
                    # 256 entries: values 0-255 map to indices 0-255
                    palette = [_parse_color(cmap[min(i * n_colors // 256, n_colors - 1)], alpha_ch_val)
                               for i in range(256)]
                else:
                    # transparent at index 0, then 255 entries for values 1-255
                    palette = [numpy.array([0, 0, 0, 0], dtype=numpy.uint8)] \
                              + [_parse_color(cmap[min(i * n_colors // 255, n_colors - 1)], alpha_ch_val)
                                 for i in range(255)]

            # replace indexes with palette colors
            if has_nan:
                # replace all NaN-s with 0 (index 0 for transparent color)
                numpy.nan_to_num(image_data, copy=False, nan=0)
            image_data = numpy.take(palette, numpy.round(image_data).astype(numpy.int32), axis=0)
        else:
            # Greyscale
            alpha_ch_scaler = 1 if alpha is None else alpha
            is_nan = numpy.isnan(image_data)
            im_shape = numpy.shape(image_data)
            # NaN pixels keep alpha 0 (fully transparent), all other pixels get the requested alpha.
            alpha_ch = numpy.zeros(im_shape, dtype=image_data.dtype)
            alpha_ch[~is_nan] = 255 * alpha_ch_scaler
            image_data[is_nan] = 0
            image_data = numpy.repeat(image_data[:, :, numpy.newaxis], 3, axis=2)  # convert to RGB
            image_data = numpy.dstack((image_data, alpha_ch))  # convert to RGBA
    else:
        # Color RGB/RGBA image
        height, width, nchannels = image_data.shape

        # Fail early with a clear message: everything below assumes 4 channels after conversion.
        if nchannels not in (3, 4):
            raise ValueError(
                "Invalid image_data: 3 (RGB) or 4 (RGBA) channels are expected but was {}".format(nchannels))

        # Make a copy:
        #   - prevent modification of the original image
        #   - drop read-only flag
        image_data = numpy.copy(image_data)
        if image_data.dtype.kind == 'f':
            # Float images are treated as 0-1 values: scale to 0-255.
            image_data *= 255

        if nchannels == 3:
            alpha_ch_scaler = 1 if alpha is None else alpha
            # RGB image: add alpha channel (RGBA)
            alpha_ch = numpy.full((height, width, 1), 255 * alpha_ch_scaler, dtype=image_data.dtype)
            image_data = numpy.dstack((image_data, alpha_ch))
        elif alpha is not None:
            # RGBA image: apply alpha scaling
            # Convert to float if needed to avoid casting errors when multiplying by alpha
            if image_data.dtype.kind != 'f':
                image_data = image_data.astype(numpy.float32)
            image_data[:, :, 3] *= alpha

    # Make sure all values are ints in range 0-255.
    image_data.clip(0, 255, out=image_data)

    # ---- Image extent ----
    # Image extent with possible axis flipping.
    # The default image bounds include 1/2 unit size expand in all directions.
    ext_x0, ext_x1, ext_y0, ext_y1 = -.5, width - .5, -.5, height - .5
    if extent:
        try:
            ext_x0, ext_x1, ext_y0, ext_y1 = [float(v) for v in extent]
        except (TypeError, ValueError) as e:
            # TypeError: non-numeric element (e.g. None); ValueError: wrong length or unparsable value.
            raise ValueError(
                "Invalid `extent`: list of 4 numbers expected: {}".format(e)
            ) from e

    if ext_x0 > ext_x1:
        # copy after flip to work around this numpy issue: https://github.com/drj11/pypng/issues/91
        image_data = numpy.flip(image_data, axis=1).copy()
        ext_x0, ext_x1 = ext_x1, ext_x0

    if ext_y0 > ext_y1:
        image_data = numpy.flip(image_data, axis=0)
        ext_y0, ext_y1 = ext_y1, ext_y0

    # ---- Encode as PNG ----
    # Make sure each value is 1 byte and the type is numpy.int8.
    # Otherwise, pypng will produce broken colors.
    if image_data.dtype.kind == 'f':
        # Can't cast directly from float to np.int8.
        image_data += 0.5  # round half up before the truncating cast
        image_data = image_data.astype(numpy.int16)

    if image_data.dtype != numpy.int8:
        image_data = image_data.astype(numpy.int8)

    # Reshape to 2d-array:
    image_2d = image_data.reshape(-1, width * 4)  # always 4 channels (RGBA)

    # PNG writer
    png_bytes = io.BytesIO()
    png.Writer(
        width=width,
        height=height,
        greyscale=False,
        alpha=True,
        bitdepth=8,
        compression=compression
    ).write(png_bytes, image_2d)

    href = 'data:image/png;base64,' + str(base64.standard_b64encode(png_bytes.getvalue()), 'utf-8')

    # ---- The Legend (colorbar) ----
    show_legend = as_boolean(show_legend, default=True)
    normalize = as_boolean(norm, default=True)
    color_scale = None
    color_scale_mapping = None
    if greyscale and show_legend:
        # aes(color=[greyscale_data_min, greyscale_data_max])
        color_scale_mapping = aes(**{color_by: [greyscale_data_min, greyscale_data_max]})
        if cmap_palettable and normalize:
            cmap_32 = palettable.get_map(cmap + "_32")
            color_scale = scale_gradientn(aesthetic=color_by, colors=cmap_32.hex_colors, name="", guide=cguide)
        elif cmap_palettable and not normalize:
            cmap_256 = palettable.get_map(cmap + "_256")
            # int(...) guarantees plain int slice indices regardless of the scalar type of min/max.
            start = max(0, int(round(greyscale_data_min)))
            end = min(255, int(round(greyscale_data_max)))
            # Slice end is exclusive: `end + 1` keeps the color of the max value
            # and avoids an empty color list when min == max.
            cmap_hex_colors = cmap_256.hex_colors[start:end + 1]
            if len(cmap_hex_colors) > 32:
                # reduce the number of colors to 32
                indices = numpy.linspace(0, len(cmap_hex_colors) - 1, 32, dtype=int)
                cmap_hex_colors = [cmap_hex_colors[i] for i in indices]

            color_scale = scale_gradientn(aesthetic=color_by, colors=cmap_hex_colors, name="", guide=cguide)
        elif cmap_list:
            # custom color list - a 'binned' colorbar
            color_scale = scale_manual(aesthetic=color_by, values=cmap, name="", guide=cguide)
        else:
            start = 0 if normalize else greyscale_data_min / 255.
            end = 1 if normalize else greyscale_data_max / 255.
            color_scale = scale_grey(aesthetic=color_by, start=start, end=end, name="", guide=cguide)

    # ---- Image geom layer ----
    geom_image_layer = _geom(
        'image',
        mapping=color_scale_mapping,
        href=href,
        xmin=ext_x0,
        ymin=ext_y0,
        xmax=ext_x1,
        ymax=ext_y1,
        show_legend=show_legend,
        inherit_aes=False,
        color_by=color_by if (show_legend and greyscale) else None,
    )

    if color_scale is not None:
        geom_image_layer = geom_image_layer + color_scale

    return geom_image_layer