#
# Copyright (c) 2019. JetBrains s.r.o.
# Use of this source code is governed by the MIT license that can be found in the LICENSE file.
#
import base64
import io

from .core import aes
from .geom import _geom
from .scale import scale_gradientn
from .scale import scale_grey
from .scale import scale_manual
from .util import as_boolean
from .._type_utils import is_ndarray

try:
    import png
except ImportError:
    png = None

try:
    import numpy
except ImportError:
    numpy = None

try:
    from palettable.matplotlib import matplotlib as palettable
except ImportError:
    palettable = None

__all__ = ['geom_imshow']


def _parse_hex_color(hex_c, alpha=None) -> 'numpy.ndarray':
    """
    Parse hex color format ('#RRGGBB' or '#RGB') to a numpy array.
    """
    hex_s = hex_c.lstrip('#')
    if len(hex_s) == 3:
        # Short format: #RGB -> #RRGGBB
        hex_s = ''.join(c + c for c in hex_s)
    list_rgb = [int(hex_s[i:i + 2], 16) for i in (0, 2, 4)]
    if alpha is not None:
        list_rgb.append(int(alpha + 0.5))
    return numpy.array(list_rgb, dtype=numpy.uint8)


def _parse_rgb_rgba_color(color_string, alpha=None) -> 'numpy.ndarray':
    """
    Parse 'rgb(r, g, b)' or 'rgba(r, g, b, a)' format to a numpy array.

    All components are expected to be integers in 0-255 range.
    The alpha parameter (0-255) is multiplied with rgba's alpha if present.
    """
    if color_string.startswith('rgba('):
        inner = color_string[5:-1]
    else:
        inner = color_string[4:-1]

    parts = [x.strip() for x in inner.split(',')]

    if len(parts) == 3:
        r, g, b = [int(x) for x in parts]
        result = [r, g, b]
        if alpha is not None:
            result.append(int(alpha + 0.5))
    elif len(parts) == 4:
        r, g, b, color_alpha = [int(x) for x in parts]
        if alpha is not None:
            # Multiply both alphas: color_alpha (0-255) * alpha (0-255) / 255
            result = [r, g, b, int(color_alpha * alpha / 255 + 0.5)]
        else:
            result = [r, g, b, color_alpha]
    else:
        raise ValueError("Invalid rgb/rgba format: {}".format(color_string))

    return numpy.array(result, dtype=numpy.uint8)


def _parse_color(color_string, alpha=None) -> 'numpy.ndarray':
    """
    Parse a color string (hex, rgb(), or rgba()) to an RGBA numpy array.

    Parameters
    ----------
    color_string : str
        Color in one of the formats:
        - '#RRGGBB' or '#RGB' (hex)
        - 'rgb(r, g, b)' with int components (0-255)
        - 'rgba(r, g, b, a)' with int components (0-255)
    alpha : optional
        Alpha value (0-255) to apply. Multiplied with existing alpha for rgba.

    Returns
    -------
    numpy.ndarray
        RGBA array with dtype uint8.
    """
    color_string = color_string.strip()

    if color_string.startswith('#'):
        return _parse_hex_color(color_string, alpha)
    elif color_string.startswith('rgb'):
        return _parse_rgb_rgba_color(color_string, alpha)
    else:
        raise ValueError("Unsupported color format: {}. Expected hex (#RRGGBB), rgb(), or rgba()".format(color_string))


def _normalize_2D(image_data, norm, vmin, vmax, min_lum):
    """
    Map a numpy 2D array of floats or ints onto the luminance range [min_lum..255].

    Parameters
    ----------
    image_data : ndarray
        2D array of scalar values. It is never modified: a copy is always made.
    norm : bool or None
        Interpreted via ``as_boolean(norm, default=True)``.
        True - values are clipped to [vmin..vmax] and linearly rescaled to [min_lum..255].
        If ``vmin == vmax``, every non-NaN value becomes 127.
        False - values are kept as they are and only clipped to [min_lum..255].
    vmin, vmax : number or None
        Data range used for rescaling. Each defaults to the NaN-ignoring min/max of the data.
    min_lum : int
        Lowest luminance to produce (negative values are treated as 0),
        e.g. 1 to keep index 0 free for NaN-s.

    Returns
    -------
    tuple
        ``(image_data, vmin, vmax)``: the new array (possibly of float dtype;
        rounding is left to the caller; NaN-s are preserved) and the data range
        as floats. Without normalization, vmin/vmax are the min/max of the clipped data.

    Raises
    ------
    ValueError
        If ``vmin > vmax``.
    """
    min_lum = max(0, min_lum)
    # Width of the output range [min_lum..255] (not its upper bound).
    lum_span = 255 - min_lum

    vmin = float(vmin if vmin is not None else numpy.nanmin(image_data))
    vmax = float(vmax if vmax is not None else numpy.nanmax(image_data))
    if vmin > vmax:
        raise ValueError("vmin value must be less than vmax value, was: {} > {}".format(vmin, vmax))

    normalize = as_boolean(norm, default=True)

    # Make a copy via `numpy.copy()` or via `arr.astype()`
    #   - prevent modification of the original image
    #   - work around a read-only flag in the original image

    if normalize:
        if vmin == vmax:
            image_data = numpy.copy(image_data)
            if image_data.dtype.kind == 'f':
                # Keep NaN-s so they are still rendered transparent,
                # like in the non-degenerate case below.
                image_data[~numpy.isnan(image_data)] = 127
            else:
                image_data[...] = 127
        else:
            # float array for scaling
            if image_data.dtype.kind == 'f':
                image_data = numpy.copy(image_data)
            else:
                image_data = image_data.astype(numpy.float32)

            image_data.clip(vmin, vmax, out=image_data)

            # [vmin..vmax] -> [min_lum..255]
            ratio = lum_span / (vmax - vmin)
            image_data -= vmin
            image_data *= ratio
            image_data += min_lum
    else:
        # no normalization
        image_data = numpy.copy(image_data)
        # Upper bound is 255 regardless of min_lum: clipping to `255 - min_lum`
        # would drop the last palette entry when index 0 is reserved for NaN-s.
        image_data.clip(min_lum, 255, out=image_data)
        vmin = float(numpy.nanmin(image_data))
        vmax = float(numpy.nanmax(image_data))

    return (image_data, vmin, vmax)


def geom_imshow(image_data, cmap=None, *,
                norm=None, alpha=None,
                vmin=None, vmax=None,
                extent=None,
                compression=None,
                show_legend=True,
                color_by="paint_c",
                cguide=None
                ):
    """
    Display an image specified by an ndarray with shape:

    - (M, N) - greyscale image
    - (M, N, 3) - color RGB image
    - (M, N, 4) - color RGBA image with an alpha channel

    This geom is not as flexible as `geom_raster() <https://lets-plot.org/python/pages/api/lets_plot.geom_raster.html>`__
    or `geom_tile() <https://lets-plot.org/python/pages/api/lets_plot.geom_tile.html>`__
    but vastly superior in terms of rendering efficiency.

    Parameters
    ----------
    image_data : ndarray
        Specify image type, size, and pixel values.
        Supported array shapes are:

        - (M, N): an image with scalar data. The values are mapped to colors (greys by default) using normalization. See parameters ``norm``, ``cmap``, ``vmin``, ``vmax``.
        - (M, N, 3): an image with RGB values (0-1 float or 0-255 int).
        - (M, N, 4): an image with RGBA values (0-1 float or 0-255 int).

        The first two dimensions (M, N) define the rows and columns of the image.
        Out-of-range values are clipped.
    cmap : str or list, optional
        Name of colormap or a list of colors.
        If a string, it should be the name of a colormap supported by the
        Palettable package (https://github.com/jiffyclub/palettable),
        for example, "viridis", "magma", "plasma", "inferno".
        If a list, it should contain color strings in hex ('#RRGGBB'),
        'rgb(r, g, b)', or 'rgba(r, g, b, a)' format with int components (0-255).
        An empty name or an empty list is rejected.
        The greyscale values will be quantized to map to the provided colors.
        This parameter is ignored for RGB(A) images.
    norm : bool
        True (default) - luminance values in greyscale images will be scaled to [0-255] range using a linear scaler.
        False - disables scaling of luminance values in greyscale images.
        This parameter is ignored for RGB(A) images.
    alpha : float, optional
        The alpha blending value, between 0 (transparent) and 1 (opaque).
    vmin, vmax : number, optional
        Define the data range used for luminance normalization in greyscale images.
        This parameter is ignored for RGB(A) images or if parameter ``norm=False``.
    extent : list of 4 numbers: [left, right, bottom, top], optional
        Define the image's bounding box in terms of the "data coordinates".

        - ``left``, ``right``: coordinates of the pixels' outer edges along the x-axis for pixels in the 1st and the last column.
        - ``bottom``, ``top``: coordinates of the pixels' outer edges along the y-axis for pixels in the 1st and the last row.

        The default is: [-0.5, ncol-0.5, -0.5, nrow-0.5]
    compression : int, optional
        The compression level to be used by the ``zlib`` module.
        Values from 0 (no compression) to 9 (highest).
        Value None means that the ``zlib`` module uses
        the default level of compression (which is generally acceptable).
    show_legend : bool, default=True
        Greyscale images only.
        False - do not show the legend for this layer.
    color_by : {'fill', 'color', 'paint_a', 'paint_b', 'paint_c'}, default='paint_c'
        Define the color aesthetic used by the legend shown for a greyscale image.
    cguide : optional
        A result of `guide_colorbar() <https://lets-plot.org/python/pages/api/lets_plot.guide_colorbar.html>`__ call.
        Use to customize the colorbar for greyscale images.

    Returns
    -------
    ``LayerSpec``
        Geom object specification.

    Notes
    -----
    This geom doesn't understand any aesthetics.
    It doesn't support color scales either.

    Examples
    --------
    .. jupyter-execute::
        :linenos:
        :emphasize-lines: 6

        import numpy as np
        from lets_plot import *
        LetsPlot.setup_html()
        np.random.seed(42)
        image = np.random.randint(256, size=(64, 64, 4))
        ggplot() + geom_imshow(image)

    |

    .. jupyter-execute::
        :linenos:
        :emphasize-lines: 7

        import numpy as np
        from lets_plot import *
        LetsPlot.setup_html()
        n = 64
        image = 256 * np.linspace(np.linspace(0, .5, n), \\
                                  np.linspace(.5, .5, n), n)
        ggplot() + geom_imshow(image, norm=False)

    |

    .. jupyter-execute::
        :linenos:
        :emphasize-lines: 6

        import numpy as np
        from lets_plot import *
        LetsPlot.setup_html()
        np.random.seed(42)
        image = np.random.normal(size=(64, 64))
        ggplot() + geom_imshow(image, vmin=-1, vmax=1)

    """

    if png is None:
        raise ValueError("pypng is not installed")

    if not is_ndarray(image_data):
        raise ValueError("Invalid image_data: ndarray is expected but was {}".format(type(image_data)))

    if image_data.ndim not in (2, 3):
        raise ValueError(
            "Invalid image_data: 2d or 3d array is expected but was {}-dimensional".format(image_data.ndim))

    if alpha is not None:
        if not (0 <= alpha <= 1):
            raise ValueError(
                "Invalid alpha: expected float in range [0..1] but was {}".format(alpha))

    if compression is not None:
        if not (0 <= compression <= 9):
            raise ValueError(
                "Invalid compression: expected integer in range [0..9] but was {}".format(compression))

    greyscale = (image_data.ndim == 2)
    if greyscale:
        # Greyscale image

        cmap_list = isinstance(cmap, list)
        cmap_palettable = isinstance(cmap, str)
        if cmap and not cmap_list and not cmap_palettable:
            raise ValueError(
                "cmap must be a string (colormap name) or a list of colors, but was {}".format(type(cmap).__name__))
        # An empty list/string is falsy: it used to fall through to plain greyscale rendering
        # and then break while building the legend. Reject it explicitly.
        if cmap_list and len(cmap) == 0:
            raise ValueError("cmap list must contain at least one color")
        if cmap_palettable and len(cmap) == 0:
            raise ValueError("cmap colormap name must not be empty")

        # `max()` propagates NaN, so this detects any NaN in the image.
        has_nan = numpy.isnan(image_data.max())
        min_lum = 0 if not (has_nan and cmap) else 1  # index 0 reserved for NaN-s

        (image_data, greyscale_data_min, greyscale_data_max) = _normalize_2D(image_data, norm, vmin, vmax, min_lum)
        height, width = image_data.shape
        has_nan = numpy.isnan(image_data.max())

        if cmap:
            alpha_ch_val = 255 if alpha is None else 255 * alpha

            if cmap_palettable:
                # colormap via palettable
                if not palettable:
                    raise ValueError(
                        "Can't process `cmap`: please install 'Palettable' (https://pypi.org/project/palettable/) to your "
                        "Python environment. "
                    )

                # prepare palette
                if not has_nan:
                    cmap_256 = palettable.get_map(cmap + "_256")
                    palette = [_parse_hex_color(c, alpha_ch_val) for c in cmap_256.hex_colors]
                else:
                    cmap_255 = palettable.get_map(cmap + "_255")
                    # transparent color at index 0
                    palette = [numpy.array([0, 0, 0, 0], dtype=numpy.uint8)] \
                              + [_parse_hex_color(c, alpha_ch_val) for c in cmap_255.hex_colors]
            else:
                # custom color list (non-empty, checked above) - build expanded palette with quantization
                n_colors = len(cmap)

                if not has_nan:
                    # 256 entries: values 0-255 map to indices 0-255
                    palette = [_parse_color(cmap[min(i * n_colors // 256, n_colors - 1)], alpha_ch_val)
                               for i in range(256)]
                else:
                    # transparent at index 0, then 255 entries for values 1-255
                    palette = [numpy.array([0, 0, 0, 0], dtype=numpy.uint8)] \
                              + [_parse_color(cmap[min(i * n_colors // 255, n_colors - 1)], alpha_ch_val)
                                 for i in range(255)]

            # replace indexes with palette colors
            if has_nan:
                # replace all NaN-s with 0 (index 0 for transparent color)
                numpy.nan_to_num(image_data, copy=False, nan=0)
            image_data = numpy.take(palette, numpy.round(image_data).astype(numpy.int32), axis=0)
        else:
            # Greyscale
            alpha_ch_scaler = 1 if alpha is None else alpha
            is_nan = numpy.isnan(image_data)
            im_shape = numpy.shape(image_data)
            # NaN pixels get a fully transparent alpha channel.
            alpha_ch = numpy.zeros(im_shape, dtype=image_data.dtype)
            alpha_ch[~is_nan] = 255 * alpha_ch_scaler
            image_data[is_nan] = 0
            image_data = numpy.repeat(image_data[:, :, numpy.newaxis], 3, axis=2)  # convert to RGB
            image_data = numpy.dstack((image_data, alpha_ch))  # convert to RGBA
    else:
        # Color RGB/RGBA image
        height, width, nchannels = image_data.shape
        # Any other channel count would be silently garbled by the RGBA reshape below.
        if nchannels not in (3, 4):
            raise ValueError(
                "Invalid image_data: 3 (RGB) or 4 (RGBA) channels expected but was {}".format(nchannels))

        # Make a copy:
        #   - prevent modification of the original image
        #   - drop read-only flag
        image_data = numpy.copy(image_data)
        if image_data.dtype.kind == 'f':
            image_data *= 255

        if nchannels == 3:
            alpha_ch_scaler = 1 if alpha is None else alpha
            # RGB image: add alpha channel (RGBA)
            alpha_ch = numpy.full((height, width, 1), 255 * alpha_ch_scaler, dtype=image_data.dtype)
            image_data = numpy.dstack((image_data, alpha_ch))
        elif alpha is not None:
            # RGBA image: apply alpha scaling
            # Convert to float if needed to avoid casting errors when multiplying by alpha
            if image_data.dtype.kind != 'f':
                image_data = image_data.astype(numpy.float32)
            image_data[:, :, 3] *= alpha

    # Make sure all values are ints in range 0-255.
    image_data.clip(0, 255, out=image_data)

    # Image extent with possible axis flipping.
    # The default image bounds include 1/2 unit size expand in all directions.
    ext_x0, ext_x1, ext_y0, ext_y1 = -.5, width - .5, -.5, height - .5
    # `is not None` rather than truthiness, so that a numpy array is accepted too.
    if extent is not None:
        try:
            extent_values = [float(v) for v in extent]
        except (TypeError, ValueError) as e:
            raise ValueError(
                "Invalid `extent`: list of 4 numbers expected: {}".format(e)
            ) from e
        if extent_values:  # an empty extent keeps the default bounds
            if len(extent_values) != 4:
                raise ValueError(
                    "Invalid `extent`: list of 4 numbers expected: got {} values".format(len(extent_values))
                )
            ext_x0, ext_x1, ext_y0, ext_y1 = extent_values

    if ext_x0 > ext_x1:
        # copy after flip to work around this numpy issue: https://github.com/drj11/pypng/issues/91
        image_data = numpy.flip(image_data, axis=1).copy()
        ext_x0, ext_x1 = ext_x1, ext_x0

    if ext_y0 > ext_y1:
        image_data = numpy.flip(image_data, axis=0)
        ext_y0, ext_y1 = ext_y1, ext_y0

    # Make sure each value is 1 byte and the type is numpy.int8.
    # Otherwise, pypng will produce broken colors.
    if image_data.dtype.kind == 'f':
        # Can't cast directly from float to np.int8.
        image_data += 0.5
        image_data = image_data.astype(numpy.int16)

    if image_data.dtype != numpy.int8:
        image_data = image_data.astype(numpy.int8)

    # Reshape to 2d-array:
    image_2d = image_data.reshape(-1, width * 4)  # always 4 channels (RGBA)

    # PNG writer
    png_bytes = io.BytesIO()
    png.Writer(
        width=width,
        height=height,
        greyscale=False,
        alpha=True,
        bitdepth=8,
        compression=compression
    ).write(png_bytes, image_2d)

    href = 'data:image/png;base64,' + str(base64.standard_b64encode(png_bytes.getvalue()), 'utf-8')

    # The Legend (colorbar)
    show_legend = as_boolean(show_legend, default=True)
    normalize = as_boolean(norm, default=True)
    color_scale = None
    color_scale_mapping = None
    if greyscale and show_legend:
        # aes(color=[greyscale_data_min, greyscale_data_max])
        color_scale_mapping = aes(**{color_by: [greyscale_data_min, greyscale_data_max]})
        if cmap_palettable and normalize:
            cmap_32 = palettable.get_map(cmap + "_32")
            color_scale = scale_gradientn(aesthetic=color_by, colors=cmap_32.hex_colors, name="", guide=cguide)
        elif cmap_palettable and not normalize:
            cmap_256 = palettable.get_map(cmap + "_256")
            start = max(0, round(greyscale_data_min))
            end = min(255, round(greyscale_data_max))
            # `end` is inclusive: the color of the maximal value belongs to the legend too
            # (an exclusive slice gave an empty list for a constant image).
            cmap_hex_colors = cmap_256.hex_colors[start:end + 1]
            if len(cmap_hex_colors) > 32:
                # reduce the number of colors to 32
                indices = numpy.linspace(0, len(cmap_hex_colors) - 1, 32, dtype=int)
                cmap_hex_colors = [cmap_hex_colors[i] for i in indices]

            color_scale = scale_gradientn(aesthetic=color_by, colors=cmap_hex_colors, name="", guide=cguide)
        elif cmap_list:
            # custom color list - a 'binned' colorbar
            color_scale = scale_manual(aesthetic=color_by, values=cmap, name="", guide=cguide)
        else:
            start = 0 if normalize else greyscale_data_min / 255.
            end = 1 if normalize else greyscale_data_max / 255.
            color_scale = scale_grey(aesthetic=color_by, start=start, end=end, name="", guide=cguide)

    # Image geom layer
    geom_image_layer = _geom(
        'image',
        mapping=color_scale_mapping,
        href=href,
        xmin=ext_x0,
        ymin=ext_y0,
        xmax=ext_x1,
        ymax=ext_y1,
        show_legend=show_legend,
        inherit_aes=False,
        color_by=color_by if (show_legend and greyscale) else None,
    )

    if color_scale is not None:
        geom_image_layer = geom_image_layer + color_scale

    return geom_image_layer
