export function getDerivedScalarType()

in neuron_viewer/src/requests/readRequests.ts [27:75]


/**
 * Maps a node type to the derived scalar type (DST) that should be requested for it.
 *
 * Some node types have two DSTs: one used for online requests and one used for offline
 * requests. For those, `online` selects which of the two is returned; for the remaining
 * supported node types the flag is ignored.
 *
 * @param nodeType - The node type to look up.
 * @param online - Whether the DST is for an online request. Defaults to `false` (offline).
 * @returns The DST matching `nodeType` and `online`.
 *
 * Layer, vocab token, QK channel and V channel nodes have no DST here: each of them hits a
 * failing `assert` with a node-specific message. Any value outside the known `NodeType`
 * members is handed to `assertUnreachable`.
 */
export function getDerivedScalarType(
  nodeType: NodeType,
  online: boolean = false
): DerivedScalarType {
  // Chooses between the online and the offline flavour of a DST based on the `online` flag.
  const pick = (
    onlineDst: DerivedScalarType,
    offlineDst: DerivedScalarType
  ): DerivedScalarType => (online ? onlineDst : offlineDst);

  switch (nodeType) {
    // --- Node types whose DST does not depend on `online`. ---
    case NodeType.MLP_NEURON:
      return DerivedScalarType.MLP_POST_ACT;
    case NodeType.RESIDUAL_STREAM_CHANNEL:
      return DerivedScalarType.RESID_POST_MLP;

    // --- Autoencoder latents: a dedicated ONLINE_* DST exists for online requests. ---
    case NodeType.AUTOENCODER_LATENT:
      return pick(
        DerivedScalarType.ONLINE_AUTOENCODER_LATENT,
        DerivedScalarType.AUTOENCODER_LATENT
      );
    case NodeType.MLP_AUTOENCODER_LATENT:
      return pick(
        DerivedScalarType.ONLINE_MLP_AUTOENCODER_LATENT,
        DerivedScalarType.MLP_AUTOENCODER_LATENT
      );
    case NodeType.ATTENTION_AUTOENCODER_LATENT:
      return pick(
        DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_LATENT,
        DerivedScalarType.ATTENTION_AUTOENCODER_LATENT
      );

    // --- DSTs defined per token pair (e.g. in attention heads). ---
    // Online requests go to the activation server, which expects the unflattened DST.
    // Offline requests go to the neuron records, which store the flattened DST.
    case NodeType.ATTENTION_HEAD:
      return pick(
        DerivedScalarType.UNFLATTENED_ATTN_WRITE_NORM,
        DerivedScalarType.ATTN_WRITE_NORM
      );
    case NodeType.AUTOENCODER_LATENT_BY_TOKEN_PAIR:
      return pick(
        DerivedScalarType.ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS,
        DerivedScalarType.FLATTENED_ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS
      );

    // --- Node types this function must never be called with. ---
    case NodeType.LAYER:
      assert(false, "getDerivedScalarType should not be called on a layer node");
      break;
    case NodeType.VOCAB_TOKEN:
      assert(false, "getDerivedScalarType should not be called on a vocab token node");
      break;
    case NodeType.QK_CHANNEL:
      assert(false, "getDerivedScalarType should not be called on a qk channel node");
      break;
    case NodeType.V_CHANNEL:
      assert(false, "getDerivedScalarType should not be called on a v channel node");
      break;

    // Exhaustiveness guard: every NodeType member must be handled above.
    default:
      return assertUnreachable(nodeType);
  }
}