function useCollatedNodeInfoWithExplanations()

in neuron_viewer/src/TransformerDebugger/cards/node_table/NodeTable.tsx [132:262]


/**
 * Collates per-node activation data from the left (and optionally right) inference
 * responses into `NodeInfo` records — merging metric values, attended-to/from token
 * strings, token-pair attribution, and top-token data — then attaches any available
 * explanation entries from `explanationMap`.
 *
 * The first memo builds the collated list; the second overlays explanations so that
 * explanation updates don't force the (more expensive) collation to rerun.
 *
 * @returns The collated `NodeInfo` list; nodes with an explanation get a copy with
 *   `explanationEntry` attached, nodes without one are returned unchanged.
 */
function useCollatedNodeInfoWithExplanations(
  rightResponseData: MultipleTopKDerivedScalarsResponseData | null,
  leftResponseData: MultipleTopKDerivedScalarsResponseData,
  leftInferenceAndTokenData: InferenceAndTokenData,
  rightInferenceAndTokenData: InferenceAndTokenData | null,
  leftTopTokensBySpecName: Record<TopTokensSpecName, TopTokens[]> | null,
  rightTopTokensBySpecName: Record<TopTokensSpecName, TopTokens[]> | null,
  leftTokenPairAttribution: Array<TopTokensAttendedTo> | null,
  rightTokenPairAttribution: Array<TopTokensAttendedTo> | null,
  explanationMap: ExplanationMap,
  tokenIndexOfInterest: number,
  // NOTE(review): commonInferenceParams is not read in this hook's body — confirm
  // whether it is still needed in the signature.
  commonInferenceParams: CommonInferenceParams
) {
  const collatedNodeInfo = React.useMemo(() => {
    if (rightResponseData) {
      assertNodeIndicesMatchExactly(leftResponseData.nodeIndices, rightResponseData.nodeIndices);
    }
    const collatedNodeInfo: NodeInfo[] = [];

    // Largest absolute value seen per metric across all nodes and both sides;
    // applied to every node in a second pass once the loop completes.
    const maxAbsByMetric: Record<Metric, number> = {} as Record<Metric, number>;
    for (let i = 0; i < leftResponseData.nodeIndices.length; i++) {
      const nodeIndex = leftResponseData.nodeIndices[i];
      // Never reassigned — only its properties are filled in below.
      const nodeInfo = createPartialNodeInfo(nodeIndex, tokenIndexOfInterest);

      // For attention-write autoencoder latents, we want to consider them as token-pair nodes.
      // This is used both for the viewer link, and the explanation fetching.
      if (nodeInfo.nodeType === NodeType.ATTENTION_AUTOENCODER_LATENT) {
        // change nodeInfo.nodeType, which is used for the viewer link
        nodeInfo.nodeType = NodeType.AUTOENCODER_LATENT_BY_TOKEN_PAIR;
        // change nodeIndex.nodeIndex.nodeType, which is used for the explanation fetching (not changing
        // nodeInfo.nodeIndex.nodeType as it creates conflicts elsewhere, but creating a new nodeIndex instead)
        nodeInfo.nodeIndex = {
          ...nodeIndex,
          nodeType: NodeType.AUTOENCODER_LATENT_BY_TOKEN_PAIR,
        };
      }

      METRICS.forEach((metric) => {
        const leftValue = leftResponseData.activationsByGroupId[GROUP_ID_BY_METRIC[metric]][i];
        const rightValue = rightResponseData?.activationsByGroupId[GROUP_ID_BY_METRIC[metric]][i];
        const diffValue = diffOptionalNumbers(leftValue, rightValue);
        // ?? (not ||) so only an unset record entry falls back to 0; Math.abs
        // values are non-negative, so the observable result is the same either way.
        maxAbsByMetric[metric] = Math.max(
          maxAbsByMetric[metric] ?? 0,
          Math.abs(leftValue ?? 0),
          Math.abs(rightValue ?? 0)
        );

        nodeInfo.metrics[metric] = {
          left: leftValue,
          right: rightValue,
          diff: diffValue,
          // maxAbs is set later, once the per-metric max over all nodes is known
        };
      });
      if (nodeInfo.attendedToSequenceTokenIndex !== null) {
        nodeInfo.leftAttendedToTokenAsString =
          leftInferenceAndTokenData.tokensAsStrings[nodeInfo.attendedToSequenceTokenIndex!];
      }
      // != null (not !== null) so a missing/undefined element — e.g. a short or
      // sparse attribution array — is skipped instead of crashing on .tokenIndices.
      if (leftTokenPairAttribution !== null && leftTokenPairAttribution[i] != null) {
        const tokenAttendedToIndex = leftTokenPairAttribution[i].tokenIndices[0];
        nodeInfo.leftAttributedToSequenceTokenIndex = tokenAttendedToIndex;
        nodeInfo.leftAttendedToTokenAsString =
          leftInferenceAndTokenData.tokensAsStrings[tokenAttendedToIndex];
      }
      nodeInfo.leftAttendedFromTokenAsString =
        leftInferenceAndTokenData.tokensAsStrings[nodeInfo.sequenceTokenIndex];
      if (leftTopTokensBySpecName !== null) {
        nodeInfo.leftTopTokensBySpecName = {};
        TOP_TOKENS_SPEC_NAMES.forEach((specName) => {
          nodeInfo.leftTopTokensBySpecName![specName] = leftTopTokensBySpecName[specName][i];
        });
      }

      if (rightResponseData) {
        if (rightInferenceAndTokenData !== null) {
          if (nodeInfo.attendedToSequenceTokenIndex !== null) {
            nodeInfo.rightAttendedToTokenAsString =
              rightInferenceAndTokenData.tokensAsStrings[nodeInfo.attendedToSequenceTokenIndex!];
          }
          // Same undefined-element guard as on the left side above.
          if (rightTokenPairAttribution !== null && rightTokenPairAttribution[i] != null) {
            const tokenAttendedToIndex = rightTokenPairAttribution[i].tokenIndices[0];
            nodeInfo.rightAttributedToSequenceTokenIndex = tokenAttendedToIndex;
            nodeInfo.rightAttendedToTokenAsString =
              rightInferenceAndTokenData.tokensAsStrings[tokenAttendedToIndex];
          }
          nodeInfo.rightAttendedFromTokenAsString =
            rightInferenceAndTokenData.tokensAsStrings[nodeInfo.sequenceTokenIndex];
        }
        if (rightTopTokensBySpecName !== null) {
          nodeInfo.rightTopTokensBySpecName = {};
          TOP_TOKENS_SPEC_NAMES.forEach((specName) => {
            nodeInfo.rightTopTokensBySpecName![specName] = rightTopTokensBySpecName[specName][i];
          });
        }
      }
      collatedNodeInfo.push(nodeInfo);
    }
    // Second pass: stamp the global per-metric max onto every node.
    for (let i = 0; i < collatedNodeInfo.length; i++) {
      METRICS.forEach((metric) => {
        collatedNodeInfo[i].metrics[metric].maxAbs = maxAbsByMetric[metric];
      });
    }
    return collatedNodeInfo;
    // Can't pass top tokens by spec name. Doing so causes the component to re-render infinitely.
    // It's safe to omit it because it only changes when the response data changes, and the response
    // data objects are already included in the dependency array.
    // NOTE(review): left/rightTokenPairAttribution are also omitted — presumably for the same
    // reason; confirm they likewise only change when the response data changes.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [
    rightResponseData,
    leftResponseData,
    rightInferenceAndTokenData,
    leftInferenceAndTokenData,
    tokenIndexOfInterest,
  ]);

  const collatedNodeInfoWithExplanations = React.useMemo(() => {
    const collatedNodeInfoWithExplanations = collatedNodeInfo.map((nodeInfo) => {
      const key = nodeToStringKey(nodeFromNodeIndex(nodeInfo.nodeIndex));
      const explanationEntry = explanationMap.get(key);
      if (explanationEntry === undefined) {
        return nodeInfo;
      }
      // Copy rather than mutate so the first memo's cached objects stay pristine.
      return {
        ...nodeInfo,
        explanationEntry,
      };
    });
    return collatedNodeInfoWithExplanations;
  }, [collatedNodeInfo, explanationMap]);
  return collatedNodeInfoWithExplanations;
}