python-package/lets_plot/bistro/corr.py (171 lines of code) (raw):
# Copyright (c) 2020. JetBrains s.r.o.
# Use of this source code is governed by the MIT license that can be found in the LICENSE file.
#
"""Correlation matrix implementation module"""
from typing import Any
from lets_plot.plot.util import is_pandas_data_frame
try:
import numpy
except ImportError:
numpy = None
try:
import pandas
except ImportError:
pandas = None
from lets_plot.plot.core import PlotSpec
__all__ = ['corr_plot']
def _is_corr_matrix(data: Any):
if is_pandas_data_frame(data):
if data.shape[0] != data.shape[1]:
return False
if not (all(col_type == 'float64' for col_type in data.dtypes)):
return False
for column in data:
import math
if not all(math.isnan(v) or (1.0 >= v >= -1.0) for v in data[column]):
return False
return True
elif isinstance(data, dict):
m = len(data.keys())
for column in data.values():
if not isinstance(column, (list, tuple)):
return False
if len(column) != m:
return False
import math
for v in column:
if not isinstance(v, float):
return False
if math.isnan(v):
return True
if 1.0 >= v >= -1.0:
return True
return False
return True
else:
return False
class corr_plot:
"""
This class is intended to build correlation matrix plot.
Notes
-----
To hide axis tooltips, set 'blank' or the result of `element_blank() <https://lets-plot.org/python/pages/api/lets_plot.element_blank.html>`__
to the ``axis_tooltip``, ``axis_tooltip_x`` or ``axis_tooltip_y`` parameter of the `theme() <https://lets-plot.org/python/pages/api/lets_plot.theme.html>`__.
Examples
--------
.. jupyter-execute::
:linenos:
:emphasize-lines: 7
import numpy as np
from lets_plot import *
from lets_plot.bistro.corr import *
LetsPlot.setup_html()
np.random.seed(42)
data = {var: np.random.poisson(size=10) for var in 'abcdef'}
corr_plot(data).tiles().build()
|
.. jupyter-execute::
:linenos:
:emphasize-lines: 7-9
import numpy as np
from lets_plot import *
from lets_plot.bistro.corr import *
LetsPlot.setup_html()
np.random.seed(42)
data = {var: np.random.uniform(size=10) for var in 'abcd'}
corr_plot(data).tiles(type='upper', diag=True)\\
.labels(type='upper', diag=True, map_size=True, color='black')\\
.palette_RdBu().build()
|
.. jupyter-execute::
:linenos:
:emphasize-lines: 7-9
import numpy as np
from lets_plot import *
from lets_plot.bistro.corr import *
LetsPlot.setup_html()
np.random.seed(42)
data = {var: np.random.normal(size=10) for var in 'abcdef'}
corr_plot(data, flip=False, threshold=.4).points().labels()\\
.palette_gradient(low='#d73027', mid='#ffffbf', high='#1a9850')\\
.build()
"""
def _duplicate(self):
dup = corr_plot(
data=self._data,
show_legend=self._show_legend,
flip=self._reverse_y,
threshold=self.threshold
)
dup._color_scale = self._color_scale
dup._fill_scale = self._fill_scale
dup._points_params = self._points_params
dup._tiles_params = self._tiles_params
dup._labels_params = self._labels_params
dup._labels_map_size = self._labels_map_size
dup._palette = self._palette
dup._low = self._low
dup._mid = self._mid
dup._high = self._high
return dup
def __init__(self, data, show_legend=True, flip=True, threshold=None):
"""
Parameters
----------
data : dict or Pandas or Polars ``DataFrame``
Correlation matrix or data (correlation will be calculated for each variable pair).
Data will be recognized as correlation matrix if it has a square shape and all values are
in range -1.0..+1.0 or NaN.
show_legend : bool, default=True
If True legend is shown.
flip : bool, default=True
If True the y axis is flipped.
threshold : float, default=0.0
Minimal correlation abs value to be included in result.
Accept values between 0 and 1.
"""
self._data = data
self._show_legend = show_legend
self._reverse_y = flip if flip else False
self.threshold = threshold
self._color_scale = None
self._fill_scale = None
self._points_params = None
self._tiles_params = None
self._labels_params = None
self._labels_map_size = None
self._palette = None
self._low = None
self._mid = None
self._high = None
def points(self, type=None, diag=None):
"""
Method defines correlation matrix layer drawn by points to the plot.
Parameters
----------
type : {'upper', 'lower', 'full'}
Type of matrix. Default - contextual.
diag : bool
Determines whether to fill the main diagonal with values or not.
Default - contextual.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_points(type, diag)
def labels(self, type=None, diag=None, map_size=None, color=None):
"""
Method defines correlation matrix layer drawn with geom_text to the plot.
Parameters
----------
type : {'upper', 'lower', 'full'}
Type of matrix. Default - contextual.
diag : bool
Determines whether to fill the main diagonal with values or not.
Default - contextual.
map_size : bool
If True, then absolute value of correlation is mapped to text size.
If False - the text size is constant. Default - contextual.
color : str
Set text color.
For more info see `Color and Fill <https://lets-plot.org/python/pages/aesthetics.html#color-and-fill>`__.
Default - contextual.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_labels(type, diag, map_size, color)
def tiles(self, type=None, diag=None):
"""
Method defines correlation matrix layer drawn as square tiles to the plot.
Parameters
----------
type : {'upper', 'lower', 'full'}
Type of matrix. Default - contextual.
diag : bool
Determines whether to fill the main diagonal with values or not.
Default - contextual.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_tiles(type, diag)
def palette_gradient(self, low, mid, high):
"""
Set `scale_color_gradient2() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_gradient2.html>`__ and
`scale_fill_gradient() <https://lets-plot.org/python/pages/api/lets_plot.scale_fill_gradient.html>`__ for corr plot.
Parameters
----------
low : str
Color for low end of gradient (correlation -1).
mid : str
Color for mid point (correlation 0).
high : str
Color for high end of gradient (correlation 1).
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_gradient_palette(low, mid, high)
def palette_BrBG(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with BrBG palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._set_brewer_palette('BrBG')
def palette_PiYG(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with PiYG palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_brewer_palette('PiYG')
def palette_PRGn(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with PRGn palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_brewer_palette('PRGn')
def palette_PuOr(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with PuOr palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_brewer_palette('PuOr')
def palette_RdBu(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with RdBu palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_brewer_palette('RdBu')
def palette_RdGy(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with RdGy palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_brewer_palette('RdGy')
def palette_RdYlBu(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with RdYlBu palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_brewer_palette('RdYlBu')
def palette_RdYlGn(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with RdYlGn palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_brewer_palette('RdYlGn')
def palette_Spectral(self):
"""
Set `scale_color_brewer() <https://lets-plot.org/python/pages/api/lets_plot.scale_color_brewer.html>`__ with Spectral palette for corr plot.
Returns
-------
``corr_plot``
Correlation plot specification.
"""
return self._duplicate()._set_brewer_palette('Spectral')
def _set_points(self, type=None, diag=None):
self._points_params = {'type': type, 'diag': diag, 'threshold': self.threshold}
return self
def _set_labels(self, type=None, diag=None, map_size=None, color=None):
self._labels_params = {'type': type, 'diag': diag, 'color': color, 'threshold': self.threshold}
self._labels_map_size = map_size
return self
def _set_tiles(self, type=None, diag=None):
self._tiles_params = {'type': type, 'diag': diag, 'threshold': self.threshold}
return self
def _set_gradient_palette(self, low, mid, high):
self._palette = 'gradient'
self._low = low
self._mid = mid
self._high = high
return self
def _set_brewer_palette(self, palette):
self._palette = palette
self._low = None
self._mid = None
self._high = None
return self
def build(self) -> PlotSpec:
"""
This method creates PlotSpec object.
Returns
-------
``PlotSpec``
Plot specification.
"""
if self._points_params is not None:
point_params = {
'type': self._points_params['type'],
'diag': self._points_params['diag']
}
else:
point_params = None
if self._labels_params is not None:
label_params = {
'type': self._labels_params['type'],
'diag': self._labels_params['diag'],
'color': self._labels_params['color'],
'map_size': self._labels_map_size
}
else:
label_params = None
if self._tiles_params is not None:
tile_params = {
'type': self._tiles_params['type'],
'diag': self._tiles_params['diag']
}
else:
tile_params = None
data = self._data
if _is_corr_matrix(data):
coefficients = True
else:
if is_pandas_data_frame(data):
data = data.corr(numeric_only=True)
coefficients = True
else:
coefficients = False
return PlotSpec(data=data, mapping=None, scales=[], layers=[], bistro={
'name': 'corr',
'coefficients': coefficients,
'show_legend': self._show_legend,
'flip': self._reverse_y,
'threshold': self.threshold,
'palette': self._palette,
'low': self._low,
'mid': self._mid,
'high': self._high,
'point_params': point_params,
'tile_params': tile_params,
'label_params': label_params
})