_draw()

in tensorboard/plugins/hparams/tf_hparams_scatter_plot_matrix_plot/tf-hparams-scatter-plot-matrix-plot.ts [188:794]


  _draw() {
    const utils = tf_hparams_utils;
    const _this = this;
    if (
      !this.sessionGroups ||
      this.sessionGroups.length == 0 ||
      !this.visibleSchema ||
      this.visibleSchema.metricInfos.length == 0
    ) {
      // If there's no metrics or session groups. There's nothing to draw.
      return;
    }
    // An array containing the visibleSchema-columns (hparams followed by
    // metrics) indices. These index the columns of the scatter plot matrix.
    const cols = d3.range(utils.numVisibleColumns(_this.visibleSchema));
    // An array containing the metric indices. These index the rows of the
    // scatter plot matrix.
    const metrics = d3.range(utils.numVisibleMetrics(_this.visibleSchema));
    // The margin in pixels from the left to leave for the y-axis text
    // (tick values and x-axis label).
    const yAxisTextMargin = 80;
    // The margin in pixels from the bottom to leave for the x-axis text
    // (tick values and x-axis label).
    const xAxisTextMargin = 50;
    // Each cell in the scatter plot matrix has a rectangular frame.
    // The margin in pixels to use between the cell boundary and its frame.
    const frameMargin = 5;
    // cellX(col), cellY(metric) are the svg coordinates of the upper
    // left corner of the boundary of the cell indexed by (col, metric).
    const cellX = d3
      .scaleBand()
      .domain(cols as any)
      .range([yAxisTextMargin + frameMargin, this.width - 1 - frameMargin])
      .paddingInner(0.1);
    const cellY = d3
      .scaleBand()
      .domain(metrics as any)
      .range([this.height - 1 - frameMargin - xAxisTextMargin, frameMargin])
      .paddingInner(0.1);
    const cellWidth = cellX.bandwidth();
    const cellHeight = cellY.bandwidth();
    // xCoords[col](colValue), yCoords[metric](metricValue) are the
    // coordinates of the marker representing the data
    // (colValue, metricValue) in the cell indexed by (col, metric).
    // The coordinates are relative to the cell boundary's upper
    // left corner.
    const xCoords = cols.map((c) => _this._cellScale(c, [0, cellWidth - 1]));
    const yCoords = metrics.map((m) =>
      _this._cellScale(m + utils.numVisibleHParams(_this.visibleSchema), [
        cellHeight - 1,
        0,
      ])
    );
    // ---------------------------------------------------------------------
    // Draw axes.
    // ---------------------------------------------------------------------
    // X-Axes and labels.
    const xAxesG = this._svg
      .selectAll('.x-axis')
      .data(cols)
      .enter()
      .append('g')
      .classed('x-axis', true)
      .attr('transform', (col) => utils.translateStr(cellX(col), 0));
    function xAxisClipPathId(col) {
      return 'x-axis-clip-path-' + col;
    }
    function xLabelClipPathId(col) {
      return 'x-label-clip-path-' + col;
    }
    xAxesG
      .append('clipPath')
      .attr('id', xAxisClipPathId)
      .append('rect')
      .attr('x', -frameMargin)
      .attr('y', 0)
      .attr('width', cellWidth + 2 * frameMargin)
      .attr('height', _this.height - xAxisTextMargin / 2);
    xAxesG
      .append('clipPath')
      .attr('id', xLabelClipPathId)
      .append('rect')
      .attr('x', 0)
      .attr('y', _this.height - xAxisTextMargin / 2)
      .attr('width', cellWidth)
      .attr('height', xAxisTextMargin / 2);
    xAxesG
      .append('g')
      .attr('clip-path', (col) => 'url(#' + xAxisClipPathId(col) + ')')
      .each(function (col) {
        d3.select(this).call(
          drawAxis,
          d3
            .axisBottom(xCoords[col] as any)
            .tickSize(_this.height - xAxisTextMargin),
          cellWidth,
          /* minLabelSize */ 40,
          _this.options.columns[col].scale
        );
      });
    // Draw a label for each axis.
    xAxesG
      .append('g')
      .classed('x-axis-label', true)
      .attr('clip-path', (col) => 'url(#' + xLabelClipPathId(col) + ')')
      .append('text')
      .attr('text-anchor', 'middle')
      .attr('x', cellWidth / 2)
      .attr('y', _this.height - 1 - xAxisTextMargin / 4)
      .text((col) => utils.schemaVisibleColumnName(_this.visibleSchema, col))
      .append('title') // Show full name as a tooltip.
      .text((col) => utils.schemaVisibleColumnName(_this.visibleSchema, col));
    // Y-Axes and labels.
    const yAxesG = this._svg
      .selectAll('.y-axis')
      .data(metrics)
      .enter()
      .append('g')
      .classed('y-axis', true)
      .attr('transform', (metric) =>
        utils.translateStr(_this.width - 1, cellY(metric))
      );
    function yAxisClipPathId(metric) {
      return 'y-axis-clip-path-' + metric;
    }
    function yLabelClipPathId(metric) {
      return 'y-label-clip-path-' + metric;
    }
    yAxesG
      .append('clipPath')
      .attr('id', yAxisClipPathId)
      .append('rect')
      .attr('x', -(_this.width - yAxisTextMargin / 2 - 1))
      .attr('y', -frameMargin)
      .attr('width', _this.width - yAxisTextMargin / 2)
      .attr('height', cellHeight + 2 * frameMargin);
    yAxesG
      .append('clipPath')
      .attr('id', yLabelClipPathId)
      .append('rect')
      .attr('x', -(_this.width - 1))
      .attr('y', 0)
      .attr('width', yAxisTextMargin / 2)
      .attr('height', cellHeight);
    yAxesG
      .append('g')
      .attr('clip-path', (metric) => 'url(#' + yAxisClipPathId(metric) + ')')
      .each(function (metric) {
        d3.select(this).call(
          drawAxis,
          d3
            .axisLeft(yCoords[metric] as any)
            .tickSize(_this.width - yAxisTextMargin),
          cellHeight,
          /* minLabelSize */ 20,
          _this.options.columns[
            metric + utils.numVisibleHParams(_this.visibleSchema)
          ].scale
        );
      });
    // Append a label for each axis.
    yAxesG
      .append('g')
      .classed('y-axis-label', true)
      .attr('clip-path', (metric) => 'url(#' + yLabelClipPathId(metric) + ')')
      .append('text')
      .attr('text-anchor', 'middle')
      .attr('x', -(_this.width - yAxisTextMargin / 4 - 1))
      .attr('y', cellHeight / 2)
      .attr(
        'transform',
        utils.rotateStr(
          90,
          -(_this.width - yAxisTextMargin / 4 - 1),
          cellHeight / 2
        )
      )
      .text((metric) =>
        utils.metricName(_this.visibleSchema.metricInfos[metric])
      )
      .append('title') // Show full name as a tooltip.
      .text((metric) =>
        utils.metricName(_this.visibleSchema.metricInfos[metric])
      );
    function drawAxis(g, axisGen, axisLength, minLabelSize, scaleType) {
      // We compute the number of ticks to display based on the estimate
      // of the minimum size to allow for a label.
      const numTicks = Math.floor(axisLength / minLabelSize);
      const scale = axisGen.scale();
      if (scaleType === 'QUANTILE') {
        // The default tickValues of a quantile scale is just the scale
        // domain, which produces overlapping labels if the number of
        // elements in the domain is greater than the number of
        // quantiles.
        let quantiles = scale.quantiles();
        const step = Math.ceil(quantiles.length / numTicks);
        quantiles = d3
          .range(0, quantiles.length, step)
          .map((i) => quantiles[i]);
        axisGen.tickValues(quantiles).tickFormat(d3.format('-.2g'));
      }
      if (scaleType === 'LINEAR' || scaleType === 'LOG') {
        // The following is equivalent to: axisGen.ticks(numTicks). We
        // use the form below, since otherwise the closure compiler
        // erroneously drops the parameter 'numTicks' from the call. It does
        // this, since d3 defines the variadic 'ticks' method as
        // function(), which closure regards as a function that takes no
        // parameters.
        axisGen['ticks'](numTicks);
      }
      g.call(axisGen);
      // Remove the actual axis line, and grey out the tick lines.
      g.selectAll('.domain').remove();
      g.selectAll('.tick line').attr('stroke', '#ddd');
    }
    // ---------------------------------------------------------------------
    // Draw cell frames.
    // ---------------------------------------------------------------------
    const cells = this._svg
      .selectAll('.cell')
      .data(d3.cross(cols, metrics))
      .enter()
      .append('g')
      .classed('cell', true)
      .attr('transform', ([col, metric]) =>
        utils.translateStr(cellX(col), cellY(metric))
      );
    const frames = cells
      .append('g')
      .classed('frame', true)
      .append('rect')
      .attr('x', -frameMargin)
      .attr('y', -frameMargin)
      .attr('width', cellWidth + 2 * frameMargin)
      .attr('height', cellHeight + 2 * frameMargin)
      .attr('stroke', '#000')
      .attr('fill', 'none')
      .attr('shape-rendering', 'crispEdges');
    // ---------------------------------------------------------------------
    // Draw data point markers.
    // ---------------------------------------------------------------------
    let colorScale = null;
    if (_this.options.colorByColumnIndex !== undefined) {
      colorScale = d3
        .scaleLinear()
        .domain(this._colExtent(this.options.colorByColumnIndex) as any)
        .range([this.options.minColor, this.options.maxColor])
        .interpolate(d3.interpolateLab as any);
    }
    // A function mapping a sessionGroup to its marker's color.
    const markerColorFn =
      _this.options.colorByColumnIndex === undefined
        ? /* Use default color if no color-by column is selected. */
          () => 'red'
        : ({sessionGroup}) =>
            colorScale(
              this._colValue(sessionGroup, _this.options.colorByColumnIndex)
            );
    // Returns the x coordinate for the marker representing sessionGroup
    // in a cell in the scatter plot matrix column indexed by 'col'.
    function markerX(sessionGroup, col) {
      return xCoords[col](_this._colValue(sessionGroup, col));
    }
    // Returns the y coordinate for the marker representing sessionGroup
    // in a cell in the scatter plot matrix row indexed by 'metric'.
    function markerY(sessionGroup, metric) {
      return yCoords[metric](_this._metricValue(sessionGroup, metric));
    }
    // A function that gets a selection of <g> elements--each should be
    // a child node of a cell <g> element (a memeber of the 'cells'
    // selection) and draws markers representing the data points in each
    // cell. The parameter 'fill' is either a constant specifying the
    // 'fill' attribute of each marker or a function taking a session
    // group that returns the fill attribute that should be set for the
    // marker representing the given session group. The function returns
    // a 3-tuple of [markers, cellMarkers, sessionGroupMarkersMap],
    // where: markers is the d3-selection of the markers,
    // cellMarkers is a 2-D array whose [col][metric] entry has a
    // d3-selection containing the markers in the [col, metric] cell,
    // and sessionGroupMarkersMap is a Map mapping a sessionGroup to the
    // array of marker HTML elements representing that sessionGroup.
    function addMarkers(cellsGSelection, fill) {
      const markers = cellsGSelection
        .selectAll('.data-marker')
        .data(([col, metric]) =>
          // Filter out session groups that don't have a metric-value
          // or a column-value for the current cell.
          _this.sessionGroups
            .filter(
              (sessionGroup) =>
                _this._colValue(sessionGroup, col) !== undefined &&
                _this._metricValue(sessionGroup, metric) !== undefined
            )
            .map((sessionGroup) => ({
              col: col,
              metric: metric,
              sessionGroup: sessionGroup,
              x: markerX(sessionGroup, col),
              y: markerY(sessionGroup, metric),
              // This will be populated by the code below with
              // a Set of all the markers representing this session
              // group.
              sessionGroupMarkers: null,
            }))
        )
        .enter()
        .append('circle')
        .classed('data-marker', true)
        .attr('cx', ({x}) => x)
        .attr('cy', ({y}) => y)
        .attr('r', 2)
        .attr('fill', fill);
      const sessionGroupMarkersMap = new Map<any, any[]>();
      _this.sessionGroups.forEach((sessionGroup) => {
        sessionGroupMarkersMap.set(sessionGroup, []);
      });
      markers.each(function (d) {
        sessionGroupMarkersMap.get(d.sessionGroup).push(this);
      });
      markers.each((d) => {
        const sessionGroupMarkers = sessionGroupMarkersMap.get(d.sessionGroup);
        d.sessionGroupMarkers = new Set(sessionGroupMarkers);
      });
      const cellMarkers = cols.map((col) =>
        metrics.map((metric) =>
          markers.filter((d) => d.col == col && d.metric == metric)
        )
      );
      return [markers, cellMarkers, sessionGroupMarkersMap];
    }
    const [markers, cellMarkers, sessionGroupMarkersMap] = addMarkers(
      cells.append('g'),
      /* fill */ markerColorFn
    );
    // ---------------------------------------------------------------------
    // Create a brush for each cell. Brushing a cell makes "visible" only
    // the markers associated with session groups whose markers
    // in the brushed cell lie within the brush selection. By "visibile"
    // here, we man colored according to color-by column. Markers that
    // are not "visibile" will be shown as grayed out.
    // ---------------------------------------------------------------------
    // For each cell, we index the markers in a quad-tree to quickly
    // find the intersection of the markers with the (brush) selection.
    // The following function creates this quad-tree for the cell indexed by
    // (col, metric). Each quad tree datum is the corresponding marker's
    // element.
    function createCellQuadTree(col, metric) {
      const data = [];
      cellMarkers[col][metric].each(function () {
        data.push(this);
      });
      return d3
        .quadtree()
        .x((elem: any) => (d3.select(elem).datum() as any).x)
        .y((elem: any) => (d3.select(elem).datum() as any).y)
        .addAll(data);
    }
    const quadTrees = cols.map((col) =>
      metrics.map((metric) => createCellQuadTree(col, metric))
    );
    // A d3-selection of the cell in 'cells' that has the active
    // brush selection, or null if the brush is not active.
    let brushedCellG = null;
    if (isBrushActive()) {
      brushedCellG = cells.filter((cellIndex) =>
        _.isEqual(cellIndex, _this._brushedCellIndex)
      );
      console.assert(brushedCellG.size() == 1, brushedCellG);
    }
    // The set of markers (in all cells) that are visible. We keep this
    // set around so that when the brush selection changes we can change
    // the "fill" attribute of only the markers we need to. This reduces
    // the browser's rendering time and makes brushing smoother.
    let visibleMarkers = new Set(markers.nodes());
    updateVisibleMarkers();
    function updateVisibleMarkers() {
      // We regard an empty (or inactive) brush selection as selecting
      // all markers.
      let newVisibleMarkers = new Set(markers.nodes());
      if (!isBrushSelectionEmpty()) {
        newVisibleMarkers = findMarkersInSelection(
          _this._brushedCellIndex,
          _this._brushSelection
        );
      }
      // Highlight the new visible markers.
      d3.selectAll(
        Array.from(
          utils.filterSet(
            newVisibleMarkers,
            (elem) => !visibleMarkers.has(elem)
          )
        ) as any
      ).attr('fill', markerColorFn);
      // Gray-out the no-longer visible markers.
      d3.selectAll(
        Array.from(
          utils.filterSet(
            visibleMarkers,
            (elem) => !newVisibleMarkers.has(elem)
          )
        ) as any
      ).attr('fill', '#ddd');
      visibleMarkers = newVisibleMarkers;
    }
    // Returns a Set of all marker elements that are in the
    // rectangle 'selection' given in coordinates relative to the
    // cell indexed by cellIndex .
    function findMarkersInSelection(cellIndex, selection) {
      console.assert(cellIndex !== null);
      console.assert(selection !== null);
      const [col, metric] = cellIndex;
      const result = new Set();
      utils.quadTreeVisitPointsInRect(
        quadTrees[col][metric],
        selection[0][0],
        selection[0][1],
        selection[1][0],
        selection[1][1],
        (elem) => {
          const data = d3.select(elem).datum() as any;
          data.sessionGroupMarkers.forEach((sg_elem) => {
            result.add(sg_elem);
          });
        }
      );
      return result;
    }
    const brush = d3
      .brush()
      .extent([
        [-frameMargin + 1, -frameMargin + 1],
        [cellWidth - 1 + frameMargin - 1, cellHeight - 1 + frameMargin - 1],
      ])
      .on('start', function () {
        if (isBrushActive() && brushedCellG.node() != this) {
          // The brush is active in a different cell.
          // Clear the selection first.
          // This will recursively call the 'start', 'brush', and
          // 'end' event listeners for the cell with the selection
          // and will update the markers. The 'if' above
          // prevents infinite recursion.
          brush.move(brushedCellG, null);
        }
        brushChanged(this);
      })
      .on('brush', function () {
        brushChanged(this);
      })
      .on('end', function () {
        brushChanged(this);
      });
    // Updates the internal state in response to a brush event in
    // the cell whose <g> element (in 'cells') is given by cellGNode
    function brushChanged(cellGNode) {
      // For some reason the closure compiler drops the argument when we
      // write the call below as 'd3.brushSelection(cellGNode)'.
      const brushSelection = d3['brushSelection'](cellGNode);
      if (
        (!isBrushActive() && brushSelection === null) ||
        (isBrushActive() &&
          cellGNode === brushedCellG.node() &&
          _.isEqual(brushSelection, _this._brushSelection))
      ) {
        // Nothing to do if selection hasn't changed.
        return;
      }
      _this._brushSelection = brushSelection;
      if (brushSelection !== null) {
        brushedCellG = d3.select(cellGNode);
        _this._brushedCellIndex = brushedCellG.datum();
      } else {
        brushedCellG = null;
        _this._brushedCellIndex = null;
      }
      updateVisibleMarkers();
    }
    function isBrushActive() {
      return _this._brushedCellIndex !== null && _this._brushSelection !== null;
    }
    function isBrushSelectionEmpty() {
      return (
        !isBrushActive() ||
        _this._brushSelection[0][0] === _this._brushSelection[1][0] ||
        _this._brushSelection[0][1] === _this._brushSelection[1][1]
      );
    }
    // Render the brush elements on each cell.
    cells.call(brush);
    if (isBrushActive()) {
      // Set the internal brush selection to what it was before
      // the 'redraw()'.
      brush.move(brushedCellG, _this._brushSelection as any);
    }
    // ---------------------------------------------------------------------
    // Add event listeners for highlighting the session group whose markers
    // are closest to the mouse pointer (only "visible" session groups
    // are considered -- see brushing above). Also, add event listeners
    // for making the highlighted session group the currently-selected
    // group by clicking.
    // ---------------------------------------------------------------------
    // A d3-selection containing the nodes in markers representing the
    // SessionGroup with a marker closest to the mouse pointer or null
    // if the distance to the closest session group is greater than a
    // threshold. This won't get set until the first mouse movement over
    // a cell.
    let closestMarkers = null;
    // A d3-selection containing the nodes in markers representing the
    // markers of the currently selected session group or null if no
    // session group is selected.
    let selectedMarkers = null;
    if (this.selectedSessionGroup !== null) {
      selectedMarkers = d3
        .selectAll(sessionGroupMarkersMap.get(this.selectedSessionGroup))
        .classed('selected-marker', true);
    }
    cells
      .on('click', function () {
        const newSelectedMarkers =
          closestMarkers === selectedMarkers ? null : closestMarkers;
        if (newSelectedMarkers === selectedMarkers) {
          return;
        }
        if (selectedMarkers !== null) {
          selectedMarkers.classed('selected-marker', false);
        }
        selectedMarkers = newSelectedMarkers;
        if (selectedMarkers !== null) {
          selectedMarkers.classed('selected-marker', true);
        }
        const newSessionGroup =
          selectedMarkers === null
            ? null
            : // All elements in selectedMarkers should have the same
              // sessionGroup.
              selectedMarkers.datum().sessionGroup;
        _this.selectedSessionGroup = newSessionGroup;
      })
      .on('mousemove mouseenter', function ([col, metric]) {
        const [x, y] = d3.mouse(this);
        const newClosestMarkers = findClosestMarkers(
          col,
          metric,
          x,
          y,
          /* threshold */ 20
        );
        if (closestMarkers === newClosestMarkers) {
          return;
        }
        if (closestMarkers !== null) {
          closestMarkers.classed('closest-marker', false);
        }
        closestMarkers = newClosestMarkers;
        if (closestMarkers !== null) {
          closestMarkers.classed('closest-marker', true);
          // All elements in closestMarkers should have the same
          // sessionGroup.
          _this.closestSessionGroup = closestMarkers.datum().sessionGroup;
        } else {
          _this.closestSessionGroup = null;
        }
      })
      .on('mouseleave', function ([col, metric]) {
        if (closestMarkers !== null) {
          closestMarkers.classed('closest-marker', false);
          closestMarkers = null;
          _this.closestSessionGroup = null;
        }
      });
    // Finds a closest visible marker in the [col,metric] cell to the point
    // with cell-relative coordinates (x,y). If that point's distance
    // to the point at (x,y) is larger than threshold, returns null;
    // otherwise returns the d3-selection consisting of the markers
    // representing the session group of that closest marker.
    function findClosestMarkers(metric, col, x, y, threshold) {
      let minDist = Infinity;
      let minSessionGroup = null;
      utils.quadTreeVisitPointsInDisk(
        quadTrees[metric][col],
        x,
        y,
        threshold,
        (elem, distanceToCenter) => {
          if (visibleMarkers.has(elem) && distanceToCenter < minDist) {
            const data = d3.select(elem).datum() as any;
            minDist = distanceToCenter;
            minSessionGroup = data.sessionGroup;
          }
        }
      );
      if (minSessionGroup === null) {
        return null;
      }
      return d3.selectAll(sessionGroupMarkersMap.get(minSessionGroup));
    }
    // ---------------------------------------------------------------------
    // Polymer adds an extra ".tf-hparams-scatter-plot-matrix-plot" class to
    // each rule selector in the <style> section written above. When
    // polymer stamps a template it adds this class to every element
    // stamped; since we're injecting our own elements here, we add this
    // class to each element so that the style rules defined above will
    // apply.
    this._svg
      .selectAll('*')
      .classed('tf-hparams-scatter-plot-matrix-plot', true);
  }