in neuron_viewer/src/TransformerDebugger/cards/node_table/NodeTable.tsx [651:897]
const columnDefs: (ColDef<NodeInfo, any> | ColGroupDef<NodeInfo>)[] = useMemo(() => {
const defaultFloatColDefs: ColDef<NodeInfo, any> = {
valueFormatter: (params: any) => formatFloat(params.value),
resizable: true,
width: 80,
sortingOrder: ["desc", "asc"],
comparator: compareWithUndefinedAsZero,
cellStyle: (params: any): CellStyle => {
const nodeInfo = params.data as NodeInfo;
const value = (params.value as number) || 0;
const color = getInterpolatedColor(
POSITIVE_NEGATIVE_COLORS,
[-1, 0, 1],
value / getMaxAbsValueForColumn(params.column.colId, nodeInfo)
);
return { backgroundColor: `rgba(${color.r}, ${color.g}, ${color.b}, 0.5)` };
},
};
let columnDefs: (ColDef<NodeInfo, any> | ColGroupDef<NodeInfo>)[] = [
{
headerName: "Name",
// Ensure that the column never scrolls out of view when scrolling horizontally.
pinned: "left",
sortable: true,
filter: true,
field: "name",
width: 140,
cellRenderer: (params: any) => {
const nodeInfo = params.data as NodeInfo;
return (
<Link
target="_blank"
className={
nodeInfo.name.startsWith("attn")
? "text-green-500 hover:text-green-700" // attention in green
: "text-blue-500 hover:text-blue-700" // everything else in blue
}
to={`../${nodeInfo.nodeType}/${nodeInfo.layerIndex}/${
nodeInfo.activationIndex
}?promptsOfInterest=${prompts.join(PROMPTS_SEPARATOR)}`}
relative="path"
>
{nodeInfo.name}
</Link>
);
},
},
{
headerName: "Tokens",
children: [
{
headerName:
commonInferenceParams.componentTypeForAttention === "autoencoder_latent"
? "Attributed to"
: "Attended to",
headerTooltip:
commonInferenceParams.componentTypeForAttention === "autoencoder_latent"
? TOKEN_ATTRIBUTED_TO_EXPLANATION
: TOKEN_ATTENDED_TO_EXPLANATION,
field: "attendedToSequenceTokenIndex",
width: 150,
cellRenderer: (params: any) => {
const nodeInfo = params.data as NodeInfo;
// for attention heads, split by token pairs, both left and right attend to the same token index
if (nodeInfo.attendedToSequenceTokenIndex !== undefined) {
return (
<TokenCell
leftTokenAsString={nodeInfo.leftAttendedToTokenAsString}
rightTokenAsString={nodeInfo.rightAttendedToTokenAsString}
leftSequenceTokenIndex={nodeInfo.attendedToSequenceTokenIndex}
rightSequenceTokenIndex={nodeInfo.attendedToSequenceTokenIndex}
/>
);
}
// for attention latents, left and right might have different attribution for top token attended to
else if (
nodeInfo.leftAttributedToSequenceTokenIndex !== undefined ||
nodeInfo.rightAttributedToSequenceTokenIndex !== undefined
) {
return (
<TokenCell
leftTokenAsString={nodeInfo.leftAttendedToTokenAsString}
rightTokenAsString={nodeInfo.rightAttendedToTokenAsString}
leftSequenceTokenIndex={nodeInfo.leftAttributedToSequenceTokenIndex}
rightSequenceTokenIndex={nodeInfo.rightAttributedToSequenceTokenIndex}
/>
);
} else {
return "";
}
},
},
{
headerName: "Attended from",
headerTooltip: TOKEN_ATTENDED_FROM_EXPLANATION,
field: "sequenceTokenIndex",
width: 150,
cellRenderer: (params: any) => {
const nodeInfo = params.data as NodeInfo;
return (
<TokenCell
leftTokenAsString={nodeInfo.leftAttendedFromTokenAsString}
rightTokenAsString={nodeInfo.rightAttendedFromTokenAsString}
leftSequenceTokenIndex={nodeInfo.sequenceTokenIndex}
rightSequenceTokenIndex={nodeInfo.sequenceTokenIndex}
shouldHighlight={nodeInfo.sequenceTokenIndex === nodeInfo.tokenIndexOfInterest}
/>
);
},
},
],
},
];
if (rightResponseData !== null) {
columnDefs.push({
headerName: "Activation",
headerTooltip: ACTIVATION_EXPLANATION,
minWidth: 400,
children: [
{
...defaultFloatColDefs,
headerName: "Left",
field: "metrics.Activation.left",
},
{
...defaultFloatColDefs,
headerName: "Diff",
field: "metrics.Activation.diff",
},
{
...defaultFloatColDefs,
headerName: "Right",
field: "metrics.Activation.right",
},
],
});
} else {
columnDefs.push({
headerName: "Activation",
headerTooltip: ACTIVATION_EXPLANATION,
...defaultFloatColDefs,
width: 100,
sortable: true,
field: "metrics.Activation.left",
});
}
if (rightResponseData !== null) {
columnDefs.push({
headerName: "Write magnitude",
headerTooltip: WRITE_MAGNITUDE_EXPLANATION,
minWidth: 400,
children: [
{
...defaultFloatColDefs,
headerName: "Left",
field: "metrics.WriteNorm.left",
},
{
...defaultFloatColDefs,
headerName: "Diff",
field: "metrics.WriteNorm.diff",
},
{
...defaultFloatColDefs,
headerName: "Right",
field: "metrics.WriteNorm.right",
},
],
});
} else {
columnDefs.push({
headerName: "Write magnitude",
headerTooltip: WRITE_MAGNITUDE_EXPLANATION,
...defaultFloatColDefs,
width: 100,
sortable: true,
field: "metrics.WriteNorm.left",
});
}
if (rightResponseData !== null) {
columnDefs.push({
headerName: "Direct effect",
headerTooltip: DIRECTION_WRITE_EXPLANATION,
children: [
{
...defaultFloatColDefs,
headerName: "Left",
field: "metrics.DirectionWrite.left",
initialSort: "desc",
},
{
...defaultFloatColDefs,
headerName: "Diff",
field: "metrics.DirectionWrite.diff",
},
{
...defaultFloatColDefs,
headerName: "Right",
field: "metrics.DirectionWrite.right",
},
],
});
} else {
columnDefs.push({
headerName: "Direct effect",
headerTooltip: DIRECTION_WRITE_EXPLANATION,
...defaultFloatColDefs,
width: 100,
sortable: true,
field: "metrics.DirectionWrite.left",
initialSort: "desc",
});
}
if (rightResponseData !== null) {
columnDefs.push({
headerName: "Estimated total effect",
headerTooltip: ACT_TIMES_GRAD_EXPLANATION,
children: [
{
...defaultFloatColDefs,
headerName: "Left",
field: "metrics.ActTimesGrad.left",
},
{
...defaultFloatColDefs,
headerName: "Diff",
field: "metrics.ActTimesGrad.diff",
},
{
...defaultFloatColDefs,
headerName: "Right",
field: "metrics.ActTimesGrad.right",
},
],
});
} else {
columnDefs.push({
...defaultFloatColDefs,
headerName: "Estimated total effect",
headerTooltip: ACT_TIMES_GRAD_EXPLANATION,
field: "metrics.ActTimesGrad.left",
width: 150,
sortable: true,
});
}
return columnDefs;
}, [rightResponseData, prompts, commonInferenceParams.componentTypeForAttention]);