sae-viewer/src/components/tokenAblationmap.tsx (101 lines of code) (raw):
import React from "react"
import { interpolateColor, Color, getInterpolatedColor, DEFAULT_COLORS, SequenceInfo } from '../types'
import Tooltip from './tooltip'
import { scaleLinear } from "d3-scale"
// Props accepted by the TokenAblationmap component.
type Props = {
// Sequence to render: its tokens plus the per-position ablation data the component reads.
info: SequenceInfo,
// Color stops handed to getInterpolatedColor; the component falls back to DEFAULT_COLORS when omitted.
colors?: Color[],
// NOTE(review): declared here but not read by TokenAblationmap in this file — confirm whether it is still needed.
boundaries?: number[],
// When falsy, newline characters inside tokens are displayed as '↵' instead of being rendered as line breaks.
renderNewlines?: boolean,
}
/**
 * Linearly rescales `arr` into [-1, 1] using a domain that is symmetric around zero.
 *
 * The domain is [-m, m], where m is the largest absolute value among the maximum and
 * minimum of `arr`, so zero stays at zero and the entry with the largest magnitude
 * lands on -1 or 1.
 *
 * @param arr values to rescale
 * @returns a new array, same length and order as `arr`, holding the rescaled values
 */
export const normalizeToUnitInterval = (arr: number[]) => {
  const largest = Math.max(...arr);
  const smallest = Math.min(...arr);
  const bound = Math.max(Math.abs(largest), Math.abs(smallest));
  // The color scale downstream covers negative values too, so the mapping has to stay
  // symmetric around zero even when only positive activations end up being displayed.
  const toSignedUnit = scaleLinear().domain([-bound, bound]).range([-1, 1]);
  return arr.map((value) => toSignedUnit(value));
}
/**
 * Builds the tooltip body for one token: its loss diff, the KL value, and a table of the
 * top upvoted / downvoted tokens (logit-based columns first, then prob-based columns).
 *
 * Per-prediction arrays on `info` are offset by one relative to `info.tokens`: entry
 * `idx - 1` belongs to token `idx`, and token 0 has no entry (zeros / empty lists are used).
 *
 * @param info sequence data holding the per-prediction arrays
 * @param idx index of the token in `info.tokens`
 * @param loss_diff raw (un-normalized) loss diff to display for this token
 */
function renderAblationTooltip(info: SequenceInfo, idx: number, loss_diff: number) {
  const isFirst = idx === 0;
  const kl = isFirst ? 0 : info.kl[idx-1];
  const top_downvotes = isFirst ? [] : info.top_downvotes_logits[idx-1];
  const top_downvote_tokens = isFirst ? [] : info.top_downvote_tokens_logits[idx-1];
  const top_upvotes = isFirst ? [] : info.top_upvotes_logits[idx-1];
  const top_upvote_tokens = isFirst ? [] : info.top_upvote_tokens_logits[idx-1];
  const top_downvotes_probs = isFirst ? [] : info.top_downvotes_probs[idx-1];
  const top_downvote_tokens_probs = isFirst ? [] : info.top_downvote_tokens_probs[idx-1];
  const top_upvotes_probs = isFirst ? [] : info.top_upvotes_probs[idx-1];
  const top_upvote_tokens_probs = isFirst ? [] : info.top_upvote_tokens_probs[idx-1];
  return <div>
    Loss diff: {loss_diff.toExponential(2)} <br/>
    KL(clean || ablated): {kl.toExponential(2)} <br/>
    Logit diffs:
    <table>
      <thead>
        <tr>
          {
            // One header group per column family; each group spans four cells of a body row.
            ['', ' (probs)'].map((suffix, i) => {
              return <React.Fragment key={i}>
                <th>Upvoted {suffix}</th><th></th>
                <th>Downvoted {suffix}</th><th></th>
              </React.Fragment>
            })
          }
        </tr>
      </thead>
      <tbody>
        {
          // Rows are driven by the upvote list; the other lists are read at the same position j.
          top_upvotes.map((upvote, j) => {
            const downvote = top_downvotes[j];
            return <tr key={j}>
              <td>{upvote.toExponential(1)}</td>
              <td style={{whiteSpace: 'pre'}}>{top_upvote_tokens[j]}</td>
              <td>{downvote.toExponential(1)}</td>
              <td style={{whiteSpace: 'pre'}}>{top_downvote_tokens[j]}</td>
              <td>{top_upvotes_probs[j].toExponential(1)}</td>
              <td style={{whiteSpace: 'pre'}}>{top_upvote_tokens_probs[j]}</td>
              <td>{top_downvotes_probs[j].toExponential(1)}</td>
              <td style={{whiteSpace: 'pre'}}>{top_downvote_tokens_probs[j]}</td>
            </tr>
          })
        }
      </tbody>
    </table>
  </div>
}

/**
 * Renders the tokens of a sequence as inline spans whose background color encodes the
 * normalized, negated ablation loss diff, each wrapped in a Tooltip.
 *
 * - Renders only a single space when `info.ablate_loss_diff` is missing.
 * - The token at `info.idx` gets a gray border.
 * - Tokens at or before `info.idx` show a fixed placeholder tooltip; later tokens show the
 *   loss diff, KL value and the up/down-vote table.
 *
 * @param info sequence data (tokens and ablation results)
 * @param colors color stops for the [-1, 0, 1] scale; defaults to DEFAULT_COLORS
 * @param renderNewlines when falsy, '\n' inside tokens is displayed as '↵'
 */
export default function TokenAblationmap({ info, colors = DEFAULT_COLORS, renderNewlines }: Props) {
  // Captured once so the guard below also covers the uses inside the map callback.
  const lossDiffs = info.ablate_loss_diff;
  if (!lossDiffs) {
    return <> </>;
  }
  // Loss diffs are negated before normalization, so a positive loss diff falls on the
  // negative side of the color scale.
  const lossDiffsNorm = normalizeToUnitInterval(lossDiffs.map((x) => (-x)));
  return (
    <div className="block" style={{width:'100%', whiteSpace: 'pre-wrap'}}>
      {info.tokens.map((token, idx) => {
        const highlight = idx === info.idx;
        // Token 0 has no loss-diff entry, so it gets the neutral midpoint of the scale.
        const color = getInterpolatedColor(colors, [-1, 0, 1], (idx === 0) ? 0 : lossDiffsNorm[idx-1]);
        const shownToken = renderNewlines ? token : token.replace(/\n/g, '↵');
        return <Tooltip
          content={
            <span
              style={{
                background: `rgba(${color.r}, ${color.g}, ${color.b}, 0.5)`,
                border: highlight ? '2px solid gray' : 'none',
                borderRadius: '2px',
              }}
            >
              {shownToken}
            </span>
          }
          // The detailed tooltip is only built for tokens after info.idx; earlier tokens
          // never display it, so their per-position lookups are skipped entirely.
          tooltip={(idx <= info.idx)
            ? <div>(prediction prior to ablated token)</div>
            : renderAblationTooltip(info, idx, (idx === 0) ? 0 : lossDiffs[idx-1])}
          key={idx}
        />
      })}
    </div>
  )
}