neuron_viewer/src/TransformerDebugger/utils/nodes.tsx (97 lines of code) (raw):
import { MirroredNodeIndex } from "../../client";
import { NodeType, Node } from "../../types";
import _ from "lodash";
/**
 * Builds a `Node` from a `MirroredNodeIndex`.
 *
 * The result keeps the index's `nodeType`, uses `0` as `layerIndex` whenever the index's
 * `layerIndex` is falsy (for example missing), and takes the last entry of
 * `tensorIndices` as its `nodeIndex`.
 *
 * @param nodeIndex - Index to convert.
 * @returns A newly created `Node`. At runtime its `nodeIndex` is `undefined` when
 *   `tensorIndices` is empty.
 */
export function nodeFromNodeIndex(nodeIndex: MirroredNodeIndex): Node {
  const { nodeType, layerIndex, tensorIndices } = nodeIndex;
  return {
    nodeType,
    // `||` maps every falsy layer index (undefined, null, 0, NaN) to 0.
    layerIndex: layerIndex || 0,
    nodeIndex: tensorIndices[tensorIndices.length - 1],
  };
}
/**
 * Gives descriptive names to the tensor indices of an attention-head node index.
 *
 * `tensorIndices` is read positionally: entry 0 is returned as
 * `attendedFromTokenIndex`, entry 1 as `attendedToTokenIndex` and entry 2 as
 * `attentionHeadIndex`.
 *
 * @param nodeIndex - Node index whose `nodeType` must be `NodeType.ATTENTION_HEAD`.
 * @returns The first three tensor indices under named keys.
 * @throws Error if `nodeIndex.nodeType` is not `NodeType.ATTENTION_HEAD`.
 */
export function namedAttentionHeadIndices(nodeIndex: MirroredNodeIndex) {
  if (nodeIndex.nodeType !== NodeType.ATTENTION_HEAD) {
    throw new Error("Incorrect nodeType for namedAttentionHeadIndices function");
  }
  const indices = nodeIndex.tensorIndices;
  return {
    attendedFromTokenIndex: indices[0],
    attendedToTokenIndex: indices[1],
    attentionHeadIndex: indices[2],
  };
}
/**
 * Gives descriptive names to the tensor indices of an MLP-neuron node index.
 *
 * `tensorIndices` is read positionally: entry 0 is returned as `sequenceTokenIndex`
 * and entry 1 as `neuronIndex`.
 *
 * @param nodeIndex - Node index whose `nodeType` must be `NodeType.MLP_NEURON`.
 * @returns The first two tensor indices under named keys.
 * @throws Error if `nodeIndex.nodeType` is not `NodeType.MLP_NEURON`.
 */
export function namedMlpNeuronIndices(nodeIndex: MirroredNodeIndex) {
  if (nodeIndex.nodeType !== NodeType.MLP_NEURON) {
    throw new Error("Incorrect nodeType for namedMlpNeuronIndices function");
  }
  const indices = nodeIndex.tensorIndices;
  return {
    sequenceTokenIndex: indices[0],
    neuronIndex: indices[1],
  };
}
/**
 * Gives descriptive names to the tensor indices of an autoencoder-latent node index.
 *
 * Accepted node types are `AUTOENCODER_LATENT`, `MLP_AUTOENCODER_LATENT` and
 * `ATTENTION_AUTOENCODER_LATENT`. `tensorIndices` is read positionally: entry 0 is
 * returned as `sequenceTokenIndex` and entry 1 as `latentIndex`.
 *
 * @param nodeIndex - Node index with one of the accepted latent node types.
 * @returns The first two tensor indices under named keys.
 * @throws Error if `nodeIndex.nodeType` is not one of the accepted latent types.
 */
export function namedAutoencoderLatentIndices(nodeIndex: MirroredNodeIndex) {
  switch (nodeIndex.nodeType) {
    case NodeType.AUTOENCODER_LATENT:
    case NodeType.MLP_AUTOENCODER_LATENT:
    case NodeType.ATTENTION_AUTOENCODER_LATENT:
      // Accepted latent type: fall out of the switch and name the indices.
      break;
    default:
      throw new Error("Incorrect nodeType for namedAutoencoderLatentIndices function");
  }
  const indices = nodeIndex.tensorIndices;
  return {
    sequenceTokenIndex: indices[0],
    latentIndex: indices[1],
  };
}
/**
 * Formats a node as the dot-separated string `<nodeType>.<layerIndex>.<nodeIndex>`.
 *
 * @param node - Node to format.
 * @returns The three fields interpolated into a single string, in that order.
 */
export function nodeToStringKey(node: Node): string {
  const { nodeType, layerIndex, nodeIndex } = node;
  return `${nodeType}.${layerIndex}.${nodeIndex}`;
}
/**
 * Produces a short display name for a node index.
 *
 * Known node types map to `<prefix>_L<layerIndex>_<activationIndex>`, where the prefix
 * is `attn`, `mlp`, `latent`, `mlp_ae` or `attn_ae` and the activation index comes from
 * `getActivationIndex`. `NodeType.LAYER` always yields `embedding`. Any other node type
 * is logged to the console and rendered as `<nodeType>.<layerIndex>.<activationIndex>`.
 *
 * @param nodeIndex - Node index to name.
 * @returns The display name.
 */
export const makeNodeName = (nodeIndex: MirroredNodeIndex) => {
  const activationIndex = getActivationIndex(nodeIndex);
  switch (nodeIndex.nodeType) {
    case NodeType.ATTENTION_HEAD:
      return `attn_L${nodeIndex.layerIndex}_${activationIndex}`;
    case NodeType.MLP_NEURON:
      return `mlp_L${nodeIndex.layerIndex}_${activationIndex}`;
    case NodeType.AUTOENCODER_LATENT:
      return `latent_L${nodeIndex.layerIndex}_${activationIndex}`;
    case NodeType.MLP_AUTOENCODER_LATENT:
      return `mlp_ae_L${nodeIndex.layerIndex}_${activationIndex}`;
    case NodeType.ATTENTION_AUTOENCODER_LATENT:
      return `attn_ae_L${nodeIndex.layerIndex}_${activationIndex}`;
    case NodeType.LAYER:
      return `embedding`;
    default:
      // Unrecognized type: report it and fall back to a generic dotted name.
      console.log(`Unknown node type ${nodeIndex.nodeType}`);
      return `${nodeIndex.nodeType}.${nodeIndex.layerIndex}.${activationIndex}`;
  }
};
/**
 * Returns the first entry of `nodeIndex.tensorIndices`.
 *
 * @param nodeIndex - Node index to read from.
 * @returns `tensorIndices[0]`; at runtime this is `undefined` when the array is empty.
 */
export function getSequenceTokenIndex(nodeIndex: MirroredNodeIndex): number {
  const { tensorIndices } = nodeIndex;
  return tensorIndices[0];
}
/**
 * Returns the last entry of `nodeIndex.tensorIndices`.
 *
 * @param nodeIndex - Node index to read from.
 * @returns The final element of `tensorIndices`; at runtime this is `undefined` when the
 *   array is empty.
 */
export function getActivationIndex(nodeIndex: MirroredNodeIndex): number {
  const { tensorIndices } = nodeIndex;
  const lastPosition = tensorIndices.length - 1;
  return tensorIndices[lastPosition];
}
/**
 * Returns the second entry of `tensorIndices` for attention-head node indices.
 *
 * @param nodeIndex - Node index to read from.
 * @returns `tensorIndices[1]` when `nodeType` is `NodeType.ATTENTION_HEAD`; `undefined`
 *   for every other node type.
 */
export function getAttendedToSequenceTokenIndex(nodeIndex: MirroredNodeIndex): number | undefined {
  return nodeIndex.nodeType === NodeType.ATTENTION_HEAD ? nodeIndex.tensorIndices[1] : undefined;
}
/**
 * Result of `joinIndices`: a merged list of node indices plus, for each merged entry,
 * where that entry sits in each of the two input arrays.
 */
export type JointIndexLookupTable = {
  /** De-duplicated combination of the right and left input node indices. */
  nodeIndices: Array<MirroredNodeIndex>;
  /**
   * Parallel to `nodeIndices`: position of the entry in the right input array, or
   * `undefined` when the entry does not occur there.
   */
  rightArrayIndices: Array<number | undefined>;
  /**
   * Parallel to `nodeIndices`: position of the entry in the left input array, or
   * `undefined` when the entry does not occur there.
   */
  leftArrayIndices: Array<number | undefined>;
};
/**
 * Merges two lists of node indices and records where each merged entry came from.
 *
 * `nodeIndices` is the concatenation of `rightIndices` followed by `leftIndices`,
 * de-duplicated with `_.uniqWith` using `_.isEqual` (value comparison, not reference
 * comparison). `rightArrayIndices` and `leftArrayIndices` have the same length as
 * `nodeIndices`; slot `k` holds the position of `nodeIndices[k]` in the corresponding
 * input array.
 *
 * Slots for entries that are absent from an input are never assigned, so they remain
 * empty array slots that read as `undefined`. If an input contains several equal
 * entries, the position of the last one is kept.
 *
 * @param rightIndices - Node indices of the right-hand side.
 * @param leftIndices - Node indices of the left-hand side.
 * @returns The merged node indices together with both position lookup arrays.
 */
export function joinIndices(
  rightIndices: MirroredNodeIndex[],
  leftIndices: MirroredNodeIndex[]
): JointIndexLookupTable {
  // Use _.isEqual to compare index values
  const nodeIndices = _.uniqWith([...rightIndices, ...leftIndices], _.isEqual);

  // Maps every row of `nodeIndices` to the position of the equal entry in `source`.
  const locateIn = (source: MirroredNodeIndex[]): (number | undefined)[] => {
    const positions: (number | undefined)[] = Array<undefined>(nodeIndices.length);
    for (let sourcePosition = 0; sourcePosition < source.length; sourcePosition++) {
      const wanted = source[sourcePosition];
      const row = nodeIndices.findIndex((candidate) => _.isEqual(candidate, wanted));
      positions[row] = sourcePosition;
    }
    return positions;
  };

  const rightArrayIndices = locateIn(rightIndices);
  const leftArrayIndices = locateIn(leftIndices);
  return { nodeIndices, rightArrayIndices, leftArrayIndices };
}