export async function heatmap()

in tfjs-vis/src/render/heatmap.ts [63:253]


/**
 * Renders a heatmap of a 2D matrix of values into `container`.
 *
 * Every matrix cell becomes one `{x, y, value}` datum of a vega-lite `rect`
 * mark: the index along the first dimension of the values is encoded on the
 * x axis, the index along the second dimension on the y axis, and the cell
 * value on the fill color.
 *
 * @param container Surface that is resolved to a draw area via `getDrawArea`.
 * @param data The values to plot (a 2D array or a tensor, which is asserted
 *     to be rank 2) plus optional `xTickLabels` / `yTickLabels`. When labels
 *     are given they are checked against the values with
 *     `assertLabelsMatchShape`.
 * @param opts Overrides merged on top of `defaultOpts`. Options read here:
 *     `rowMajor`, `width`, `height`, `fontSize`, `xType`, `yType`, `xLabel`,
 *     `yLabel`, `colorMap` and `domain`.
 * @returns A promise that resolves once `embed` has completed.
 */
export async function heatmap(
    container: Drawable, data: HeatmapData,
    opts: HeatmapOptions = {}): Promise<void> {
  const options = Object.assign({}, defaultOpts, opts);
  const drawArea = getDrawArea(container);

  let inputValues = data.values;
  if (options.rowMajor) {
    inputValues = await convertToRowMajor(data.values);
  }

  // Data validation
  const {xTickLabels, yTickLabels} = data;
  if (xTickLabels != null) {
    const dimension = 0;
    assertLabelsMatchShape(inputValues, xTickLabels, dimension);
  }

  // Note that we will only do a check on the first element of the second
  // dimension. We do not protect users against passing in a ragged array.
  if (yTickLabels != null) {
    const dimension = 1;
    assertLabelsMatchShape(inputValues, yTickLabels, dimension);
  }

  //
  // Format data for vega spec; an array of objects, one for each cell
  // in the matrix.
  //
  // If custom labels are passed in for xTickLabels or yTickLabels we need
  // to make sure they are 'unique' before mapping them to visual properties.
  // We therefore append the index of the label to the datum that will be used
  // for that label in the x or y axis. We could do this in all cases but
  // choose not to, in order to avoid unnecessary string operations.
  //
  // We use IDX_SEPARATOR to demarcate the added index
  const IDX_SEPARATOR = '@tfidx@';

  const values: MatrixEntry[] = [];
  if (inputValues instanceof tf.Tensor) {
    assert(
        inputValues.rank === 2,
        'Input to renderHeatmap must be a 2d array or Tensor2d');

    // This is a slightly specialized version of TensorBuffer.get, inlining it
    // avoids the overhead of a function call per data element access and is
    // specialized to only deal with the 2d case.
    const inputArray = await inputValues.data();
    const [numRows, numCols] = inputValues.shape;

    // The y key depends only on the column, so build each key once instead
    // of once per cell.
    const yKeys: Array<string|number> = [];
    for (let col = 0; col < numCols; col++) {
      yKeys.push(
          yTickLabels ? `${yTickLabels[col]}${IDX_SEPARATOR}${col}` : col);
    }

    for (let row = 0; row < numRows; row++) {
      const x = xTickLabels ? `${xTickLabels[row]}${IDX_SEPARATOR}${row}` : row;
      // Offset of the first element of this row in the flat data array.
      const rowOffset = row * numCols;
      for (let col = 0; col < numCols; col++) {
        const y = yKeys[col];
        const value = inputArray[rowOffset + col];

        values.push({x, y, value});
      }
    }
  } else {
    const inputArray = inputValues;
    for (let row = 0; row < inputArray.length; row++) {
      const x = xTickLabels ? `${xTickLabels[row]}${IDX_SEPARATOR}${row}` : row;
      // Rows are iterated by their own length, so ragged arrays are emitted
      // as-is rather than rejected.
      for (let col = 0; col < inputArray[row].length; col++) {
        const y =
            yTickLabels ? `${yTickLabels[col]}${IDX_SEPARATOR}${col}` : col;
        const value = inputArray[row][col];
        values.push({x, y, value});
      }
    }
  }

  const embedOpts = {
    actions: false,
    mode: 'vega-lite' as Mode,
    defaultStyle: false,
  };

  const spec: VisualizationSpec = {
    'width': options.width || getDefaultWidth(drawArea),
    'height': options.height || getDefaultHeight(drawArea),
    'padding': 0,
    'autosize': {
      'type': 'fit',
      'contains': 'padding',
      'resize': true,
    },
    'config': {
      'axis': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'text': {'fontSize': options.fontSize},
      'legend': {
        'labelFontSize': options.fontSize,
        'titleFontSize': options.fontSize,
      },
      'scale': {'bandPaddingInner': 0, 'bandPaddingOuter': 0},
    },
    //@ts-ignore
    'data': {'values': values},
    'mark': {'type': 'rect', 'tooltip': true},
    'encoding': {
      'x': {
        'field': 'x',
        'type': options.xType,
        'title': options.xLabel,
        'sort': false,
      },
      'y': {
        'field': 'y',
        'type': options.yType,
        'title': options.yLabel,
        'sort': false,
      },
      'fill': {
        'field': 'value',
        'type': 'quantitative',
      }
    }
  };

  //
  // Format custom labels to remove the appended indices
  //
  const suffixPattern = `${IDX_SEPARATOR}\\d+$`;
  const suffixRegex = new RegExp(suffixPattern);
  // Vega expression that strips the index suffix from an axis label.
  const stripSuffixExpr =
      `replace(datum.value, regexp(/${suffixPattern}/), '')`;
  if (xTickLabels) {
    // @ts-ignore
    spec.encoding.x.axis = {
      'labelExpr': stripSuffixExpr,
    };
  }

  if (yTickLabels) {
    // @ts-ignore
    spec.encoding.y.axis = {
      'labelExpr': stripSuffixExpr,
    };
  }

  // Customize tooltip formatting to remove the appended indices
  if (xTickLabels || yTickLabels) {
    //@ts-ignore
    embedOpts.tooltip = {
      sanitize: (value: string|number) => {
        // Supplying `sanitize` replaces the tooltip handler's default
        // sanitizer, which is what escapes HTML. Re-apply that escaping after
        // stripping the suffix so label/value text is never interpreted as
        // markup when the tooltip is rendered.
        return String(value)
            .replace(suffixRegex, '')
            .replace(/&/g, '&amp;')
            .replace(/</g, '&lt;');
      }
    };
  }

  let colorRange: string[]|string;
  switch (options.colorMap) {
    case 'blues':
      colorRange = ['#f7fbff', '#4292c6'];
      break;
    case 'greyscale':
      colorRange = ['#000000', '#ffffff'];
      break;
    case 'viridis':
    default:
      colorRange = 'viridis';
      break;
  }

  // No explicit range is set for 'viridis'; only the other color maps
  // override the fill scale's range.
  if (colorRange !== 'viridis') {
    //@ts-ignore
    const fill = spec.encoding.fill;
    // @ts-ignore
    fill.scale = {'range': colorRange};
  }

  if (options.domain) {
    //@ts-ignore
    const fill = spec.encoding.fill;
    // Object.assign skips null/undefined sources, so this both extends an
    // existing scale (keeping its range) and creates one when none is set.
    // @ts-ignore
    fill.scale = Object.assign({}, fill.scale, {'domain': options.domain});
  }

  await embed(drawArea, spec, embedOpts);
}