in neuron_viewer/src/TransformerDebugger/cards/node_table/NodeTable.tsx [132:262]
/**
 * Collates per-node data from the left (and optional right) inference responses into a
 * `NodeInfo[]` for the node table, then attaches explanation entries from `explanationMap`.
 *
 * For every position `i` in `leftResponseData.nodeIndices`, metric values, top tokens and
 * token-pair attribution are read from position `i` of the corresponding inputs, so those
 * per-node arrays are expected to be aligned with `nodeIndices`.
 *
 * @param rightResponseData Response for the comparison ("right") side, or null. When present,
 *   its node indices must match the left side exactly (checked via
 *   `assertNodeIndicesMatchExactly`).
 * @param leftResponseData Response for the primary ("left") side; defines the node order.
 * @param leftInferenceAndTokenData Left-side tokens, used to look up token strings.
 * @param rightInferenceAndTokenData Right-side tokens, or null.
 * @param leftTopTokensBySpecName Per-spec top tokens indexed by node position, or null.
 * @param rightTopTokensBySpecName Per-spec top tokens indexed by node position, or null.
 * @param leftTokenPairAttribution Per-node attended-to attribution for the left side, or null.
 * @param rightTokenPairAttribution Per-node attended-to attribution for the right side, or null.
 * @param explanationMap Explanations keyed by `nodeToStringKey(nodeFromNodeIndex(...))`.
 * @param tokenIndexOfInterest Forwarded to `createPartialNodeInfo`.
 * @param commonInferenceParams Not read by this hook.
 * @returns The collated node info; nodes with an explanation get an `explanationEntry` field.
 */
function useCollatedNodeInfoWithExplanations(
  rightResponseData: MultipleTopKDerivedScalarsResponseData | null,
  leftResponseData: MultipleTopKDerivedScalarsResponseData,
  leftInferenceAndTokenData: InferenceAndTokenData,
  rightInferenceAndTokenData: InferenceAndTokenData | null,
  leftTopTokensBySpecName: Record<TopTokensSpecName, TopTokens[]> | null,
  rightTopTokensBySpecName: Record<TopTokensSpecName, TopTokens[]> | null,
  leftTokenPairAttribution: Array<TopTokensAttendedTo> | null,
  rightTokenPairAttribution: Array<TopTokensAttendedTo> | null,
  explanationMap: ExplanationMap,
  tokenIndexOfInterest: number,
  commonInferenceParams: CommonInferenceParams
) {
  const collatedNodeInfo = React.useMemo(() => {
    if (rightResponseData) {
      assertNodeIndicesMatchExactly(leftResponseData.nodeIndices, rightResponseData.nodeIndices);
    }
    const nodeInfos: NodeInfo[] = [];
    // Largest |value| per metric across all nodes and both sides. It can only be known after
    // every node is visited, so it is written into each node's `maxAbs` in a second pass below.
    const maxAbsByMetric: Record<Metric, number> = {} as Record<Metric, number>;
    for (let i = 0; i < leftResponseData.nodeIndices.length; i++) {
      const nodeIndex = leftResponseData.nodeIndices[i];
      const nodeInfo = createPartialNodeInfo(nodeIndex, tokenIndexOfInterest);
      // Attention-write autoencoder latents are treated as token-pair nodes, both for the viewer
      // link (nodeInfo.nodeType) and for explanation lookup (nodeInfo.nodeIndex.nodeType).
      if (nodeInfo.nodeType === NodeType.ATTENTION_AUTOENCODER_LATENT) {
        nodeInfo.nodeType = NodeType.AUTOENCODER_LATENT_BY_TOKEN_PAIR;
        // Replace nodeInfo.nodeIndex with a fresh object rather than mutating its nodeType in
        // place: in-place mutation creates conflicts elsewhere (presumably because the object is
        // shared with leftResponseData.nodeIndices).
        nodeInfo.nodeIndex = {
          ...nodeIndex,
          nodeType: NodeType.AUTOENCODER_LATENT_BY_TOKEN_PAIR,
        };
      }
      METRICS.forEach((metric) => {
        const groupId = GROUP_ID_BY_METRIC[metric];
        const leftValue = leftResponseData.activationsByGroupId[groupId][i];
        const rightValue = rightResponseData?.activationsByGroupId[groupId][i];
        maxAbsByMetric[metric] = Math.max(
          maxAbsByMetric[metric] || 0,
          Math.abs(leftValue ?? 0),
          Math.abs(rightValue ?? 0)
        );
        nodeInfo.metrics[metric] = {
          left: leftValue,
          right: rightValue,
          diff: diffOptionalNumbers(leftValue, rightValue),
          // maxAbs is filled in after all nodes have been processed.
        };
      });

      // Left-side token strings. Token-pair attribution, when present, overrides the
      // attended-to token derived from the node itself.
      const attendedToIndex = nodeInfo.attendedToSequenceTokenIndex;
      if (attendedToIndex != null) {
        nodeInfo.leftAttendedToTokenAsString =
          leftInferenceAndTokenData.tokensAsStrings[attendedToIndex];
      }
      // `?.[i]` plus `!= null` also skips entries that are undefined (holes / short arrays),
      // which previously slipped past a `!== null` check and crashed on `.tokenIndices`.
      const leftAttribution = leftTokenPairAttribution?.[i];
      if (leftAttribution != null) {
        const tokenAttendedToIndex = leftAttribution.tokenIndices[0];
        nodeInfo.leftAttributedToSequenceTokenIndex = tokenAttendedToIndex;
        nodeInfo.leftAttendedToTokenAsString =
          leftInferenceAndTokenData.tokensAsStrings[tokenAttendedToIndex];
      }
      nodeInfo.leftAttendedFromTokenAsString =
        leftInferenceAndTokenData.tokensAsStrings[nodeInfo.sequenceTokenIndex];
      if (leftTopTokensBySpecName !== null) {
        nodeInfo.leftTopTokensBySpecName = {};
        TOP_TOKENS_SPEC_NAMES.forEach((specName) => {
          nodeInfo.leftTopTokensBySpecName![specName] = leftTopTokensBySpecName[specName][i];
        });
      }

      // Right-side data is only populated when there is a comparison response.
      if (rightResponseData) {
        if (rightInferenceAndTokenData !== null) {
          if (attendedToIndex != null) {
            nodeInfo.rightAttendedToTokenAsString =
              rightInferenceAndTokenData.tokensAsStrings[attendedToIndex];
          }
          const rightAttribution = rightTokenPairAttribution?.[i];
          if (rightAttribution != null) {
            const tokenAttendedToIndex = rightAttribution.tokenIndices[0];
            nodeInfo.rightAttributedToSequenceTokenIndex = tokenAttendedToIndex;
            nodeInfo.rightAttendedToTokenAsString =
              rightInferenceAndTokenData.tokensAsStrings[tokenAttendedToIndex];
          }
          nodeInfo.rightAttendedFromTokenAsString =
            rightInferenceAndTokenData.tokensAsStrings[nodeInfo.sequenceTokenIndex];
        }
        if (rightTopTokensBySpecName !== null) {
          nodeInfo.rightTopTokensBySpecName = {};
          TOP_TOKENS_SPEC_NAMES.forEach((specName) => {
            nodeInfo.rightTopTokensBySpecName![specName] = rightTopTokensBySpecName[specName][i];
          });
        }
      }
      nodeInfos.push(nodeInfo);
    }
    for (const nodeInfo of nodeInfos) {
      METRICS.forEach((metric) => {
        nodeInfo.metrics[metric].maxAbs = maxAbsByMetric[metric];
      });
    }
    return nodeInfos;
    // Can't pass top tokens by spec name. Doing so causes the component to re-render infinitely.
    // It's safe to omit it because it only changes when the response data changes, and the response
    // data objects are already included in the dependency array.
    // NOTE(review): leftTokenPairAttribution / rightTokenPairAttribution are read above but are
    // also omitted, so a change to them alone will not recompute this memo. Presumably they share
    // the response data's lifecycle - TODO confirm.
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [
    rightResponseData,
    leftResponseData,
    rightInferenceAndTokenData,
    leftInferenceAndTokenData,
    tokenIndexOfInterest,
  ]);
  // Kept as a separate memo so explanation updates don't redo the collation above.
  const collatedNodeInfoWithExplanations = React.useMemo(
    () =>
      collatedNodeInfo.map((nodeInfo) => {
        const key = nodeToStringKey(nodeFromNodeIndex(nodeInfo.nodeIndex));
        const explanationEntry = explanationMap.get(key);
        if (explanationEntry === undefined) {
          return nodeInfo;
        }
        return {
          ...nodeInfo,
          explanationEntry,
        };
      }),
    [collatedNodeInfo, explanationMap]
  );
  return collatedNodeInfoWithExplanations;
}