export function addHealthPill()

in tensorboard/plugins/graph/tf_graph_common/scene.ts [453:632]


/**
 * Draws a health pill for a node: a horizontal bar whose colored bands show
 * how the tensor's elements split across categories, a hover title with
 * detailed counts, and a short text label with the value range.
 *
 * Every existing `.health-pill` element under the parent of
 * `nodeGroupElement` is removed first, so calling this again replaces the
 * previous pill instead of stacking another one. If `healthPill` is falsy,
 * only that removal happens.
 *
 * @param nodeGroupElement SVG group of the node; the new pill group is
 *     appended to this element's parent node.
 * @param healthPill Health pill whose `value` array is visualized.
 * @param nodeInfo Render info used to place the pill just above the node and
 *     to read the op's `T` attribute. When null, no transform is set on the
 *     pill group and integer rounding of the label is not applied.
 * @param healthPillId Suffix for the gradient element id, which must be
 *     unique within the page.
 * @param healthPillWidth Width of the bar (default 60; null is also treated
 *     as 60). Halved for OP nodes.
 * @param healthPillHeight Height of the bar (default 10; null is also
 *     treated as 10). Halved for OP nodes.
 * @param healthPillYOffset `y` of the bar inside the pill group (default 0).
 *     The label is drawn 2 units above it.
 * @param textXOffset `x` of the label. Defaults to the horizontal center of
 *     the (possibly halved) bar.
 */
export function addHealthPill(
  nodeGroupElement: SVGElement,
  healthPill: HealthPill,
  nodeInfo: render.RenderNodeInfo,
  healthPillId: number,
  healthPillWidth = 60,
  healthPillHeight = 10,
  healthPillYOffset = 0,
  textXOffset?: number
) {
  // Remove any health pill already drawn under this parent so re-rendering
  // replaces it rather than layering a second pill on top.
  d3.select(nodeGroupElement.parentNode as any)
    .selectAll('.health-pill')
    .remove();
  if (!healthPill) {
    return;
  }
  const lastHealthPillData = healthPill.value;
  // Entries of the health pill value array read by this function:
  //   [1]     total element count
  //   [2..7]  element counts per category, in this order:
  //           NaN, -Inf, negative, zero, positive, +Inf
  //   [8]     min, [9] max, [10] mean, [11] variance
  // Only the six per-category counts are drawn as bands in the bar.
  const lastHealthPillElementsBreakdown = lastHealthPillData.slice(2, 8);
  const nanCount = lastHealthPillElementsBreakdown[0];
  const negInfCount = lastHealthPillElementsBreakdown[1];
  const posInfCount = lastHealthPillElementsBreakdown[5];
  const totalCount = lastHealthPillData[1];
  const numericStats: HealthPillNumericStats = {
    min: lastHealthPillData[8],
    max: lastHealthPillData[9],
    mean: lastHealthPillData[10],
    stddev: Math.sqrt(lastHealthPillData[11]),
  };
  // Default parameters only replace `undefined`; these checks also cover an
  // explicit `null` passed by a caller.
  if (healthPillWidth == null) {
    healthPillWidth = 60;
  }
  if (healthPillHeight == null) {
    healthPillHeight = 10;
  }
  if (healthPillYOffset == null) {
    healthPillYOffset = 0;
  }
  if (nodeInfo != null && nodeInfo.node.type === NodeType.OP) {
    // Use a smaller health pill for op nodes (rendered as smaller ellipses).
    healthPillWidth /= 2;
    healthPillHeight /= 2;
  }

  const healthPillGroup = document.createElementNS(SVG_NAMESPACE, 'g');
  healthPillGroup.classList.add('health-pill');

  // Define the gradient that paints the bar.
  const healthPillDefs = document.createElementNS(SVG_NAMESPACE, 'defs');
  healthPillGroup.appendChild(healthPillDefs);
  const healthPillGradient = document.createElementNS(
    SVG_NAMESPACE,
    'linearGradient'
  );
  // The gradient is referenced as url(#id), so its id must be unique in the
  // whole page.
  const healthPillGradientId = 'health-pill-gradient-' + healthPillId;
  healthPillGradient.setAttribute('id', healthPillGradientId);
  // Each non-empty category becomes one color interval. It is defined by two
  // stops of the same color at the interval's start and end offsets, and its
  // width is proportional to the category's share of totalCount.
  let cumulativeCount = 0;
  let previousOffset = '0%';
  for (let i = 0; i < lastHealthPillElementsBreakdown.length; i++) {
    const categoryCount = lastHealthPillElementsBreakdown[i];
    if (!categoryCount) {
      // Exclude empty categories.
      continue;
    }
    cumulativeCount += categoryCount;
    const color = healthPillEntries[i].background_color;
    const offset = (cumulativeCount * 100) / totalCount + '%';
    const startStop = document.createElementNS(SVG_NAMESPACE, 'stop');
    startStop.setAttribute('offset', previousOffset);
    startStop.setAttribute('stop-color', color);
    healthPillGradient.appendChild(startStop);
    const endStop = document.createElementNS(SVG_NAMESPACE, 'stop');
    endStop.setAttribute('offset', offset);
    endStop.setAttribute('stop-color', color);
    healthPillGradient.appendChild(endStop);
    previousOffset = offset;
  }
  healthPillDefs.appendChild(healthPillGradient);

  // Create the rectangle for the health pill.
  const rect = document.createElementNS(SVG_NAMESPACE, 'rect');
  rect.setAttribute('fill', 'url(#' + healthPillGradientId + ')');
  rect.setAttribute('width', String(healthPillWidth));
  rect.setAttribute('height', String(healthPillHeight));
  rect.setAttribute('y', String(healthPillYOffset));
  healthPillGroup.appendChild(rect);

  // Show a title with specific counts on hover.
  const titleSvg = document.createElementNS(SVG_NAMESPACE, 'title');
  titleSvg.textContent = _getHealthPillTextContent(
    healthPill,
    totalCount,
    lastHealthPillElementsBreakdown,
    numericStats
  );
  healthPillGroup.appendChild(titleSvg);

  // Center this health pill just above the node for the op.
  let shouldRoundOnesDigit = false;
  if (nodeInfo != null) {
    const healthPillX = nodeInfo.x - healthPillWidth / 2;
    let healthPillY =
      nodeInfo.y - healthPillHeight - nodeInfo.height / 2 - 2;
    if (nodeInfo.labelOffset < 0) {
      // The label is positioned above the node. Do not occlude the label.
      healthPillY += nodeInfo.labelOffset;
    }
    healthPillGroup.setAttribute(
      'transform',
      `translate(${healthPillX}, ${healthPillY})`
    );
    if (
      lastHealthPillElementsBreakdown[2] ||
      lastHealthPillElementsBreakdown[3] ||
      lastHealthPillElementsBreakdown[4]
    ) {
      // At least one non-Inf, non-NaN value exists (a negative, zero or
      // positive value), so a range label will be shown. If the op's `T`
      // attribute names a bool/int/uint dtype, the range is shown as
      // integers.
      const node = nodeInfo.node as OpNode;
      const attributes = node.attr;
      if (attributes && attributes.length) {
        for (let i = 0; i < attributes.length; i++) {
          if (attributes[i].key === 'T') {
            const outputType = attributes[i].value['type'];
            // Coerce to a real boolean: a missing type must yield `false`,
            // not `undefined` or ''.
            shouldRoundOnesDigit =
              !!outputType && /^DT_(BOOL|INT|UINT)/.test(outputType);
            break;
          }
        }
      }
    }
  }

  // Build the label text in a local string rather than appending to
  // `textContent`, which DOM typings declare as `string | null`.
  let statsText: string;
  if (Number.isFinite(numericStats.min) && Number.isFinite(numericStats.max)) {
    const minString = humanizeHealthPillStat(
      numericStats.min,
      shouldRoundOnesDigit
    );
    const maxString = humanizeHealthPillStat(
      numericStats.max,
      shouldRoundOnesDigit
    );
    // A single element has no meaningful range; show just its value.
    statsText = totalCount > 1 ? minString + ' ~ ' + maxString : minString;
    // Append the counts of non-finite elements, if any.
    const badValueStrings: string[] = [];
    if (nanCount > 0) {
      badValueStrings.push(`NaN×${nanCount}`);
    }
    if (negInfCount > 0) {
      badValueStrings.push(`-∞×${negInfCount}`);
    }
    if (posInfCount > 0) {
      badValueStrings.push(`+∞×${posInfCount}`);
    }
    if (badValueStrings.length > 0) {
      statsText += ' (' + badValueStrings.join('; ') + ')';
    }
  } else {
    statsText = '(No finite elements)';
  }
  const statsSvg = document.createElementNS(SVG_NAMESPACE, 'text');
  statsSvg.textContent = statsText;
  statsSvg.classList.add('health-pill-stats');
  if (textXOffset == null) {
    textXOffset = healthPillWidth / 2;
  }
  statsSvg.setAttribute('x', String(textXOffset));
  statsSvg.setAttribute('y', String(healthPillYOffset - 2));
  healthPillGroup.appendChild(statsSvg);
  (PolymerDom.dom(nodeGroupElement.parentNode) as any).appendChild(
    healthPillGroup
  );
}