async fetch()

in neuron_viewer/src/TransformerDebugger/requests/inferenceDataFetcher.ts [141:219]


  /**
   * Builds the inference request(s) for the current prompt configuration, publishes them through
   * the request setters, sends them to the server as one batched request, and routes each
   * sub-response to the matching response setter.
   *
   * Request construction:
   *   - The left request is always built from `leftPromptInferenceParams`.
   *   - If `rightPromptInferenceParams` is non-null (two prompts are being compared), the right
   *     request covers the right prompt but reuses the left prompt's `nodeAblations` and
   *     `upstreamNodeToTrace`. Those settings are stored on the left side only.
   *   - Otherwise, if the left prompt has ablations, the right request is the left prompt with its
   *     ablations removed.
   *   - Otherwise no right request is made, and `setRightRequest(null)` is called.
   *
   * Both response slots are cleared before the request is sent. On success, the server error
   * message is cleared. On failure, the error is passed to `handleInferenceError`. Failure includes
   * the server returning a different number of sub-responses than sub-requests were sent.
   *
   * Does nothing (and clears nothing) if `modelInfo` is null.
   *
   * The returned promise settles only after the batched request has been handled, so callers that
   * `await` this method observe the response/error setters having been called.
   *
   * NOTE(review): overlapping calls are not cancelled or sequenced here. If an earlier request
   * resolves after a later one, its responses overwrite the newer ones. Consider tracking a
   * request id if this matters to callers.
   */
  async fetch(
    modelInfo: ModelInfoResponse | null,
    commonInferenceParams: CommonInferenceParams,
    leftPromptInferenceParams: PromptInferenceParams,
    rightPromptInferenceParams: PromptInferenceParams | null,
    setRightResponse: React.Dispatch<React.SetStateAction<InferenceResponseAndResponseDict | null>>,
    setLeftResponse: React.Dispatch<React.SetStateAction<InferenceResponseAndResponseDict | null>>,
    setRightRequest: React.Dispatch<React.SetStateAction<TdbRequestSpec | null>>,
    setLeftRequest: React.Dispatch<React.SetStateAction<TdbRequestSpec | null>>,
    setActivationServerErrorMessage: React.Dispatch<React.SetStateAction<string | null>>
  ) {
    if (modelInfo === null) {
      return;
    }
    setRightResponse(null);
    setLeftResponse(null);

    const newLeftRequest = buildTdbRequestSpec(leftPromptInferenceParams, commonInferenceParams);
    setLeftRequest(newLeftRequest);

    let newRightRequest: TdbRequestSpec | null;
    if (rightPromptInferenceParams !== null) {
      // We're comparing two prompts. The right request covers the right prompt and both prompts use
      // the same ablations (which are stored on the left).
      newRightRequest = buildTdbRequestSpec(
        {
          ...rightPromptInferenceParams,
          nodeAblations: leftPromptInferenceParams.nodeAblations,
          upstreamNodeToTrace: leftPromptInferenceParams.upstreamNodeToTrace,
        },
        commonInferenceParams
      );
    } else if (leftPromptInferenceParams.nodeAblations.length !== 0) {
      newRightRequest = buildTdbRequestSpec(
        {
          ...leftPromptInferenceParams,
          // The right request omits the ablations specified in the left request.
          nodeAblations: Array<NodeAblation>(),
        },
        commonInferenceParams
      );
    } else {
      // If there is no right prompt and no ablations, there is no need to make a separate right
      // request.
      newRightRequest = null;
    }
    setRightRequest(newRightRequest);

    // subRequests[i]'s response is delivered to setResponseFns[i]; keep the two arrays aligned.
    const subRequests = [newLeftRequest];
    const setResponseFns = [setLeftResponse];
    if (newRightRequest !== null) {
      subRequests.push(newRightRequest);
      setResponseFns.push(setRightResponse);
    }

    // Awaited so that `await fetch(...)` waits for the request to be fully handled. Errors are
    // routed to handleInferenceError inside the chain. The arrow function keeps `this` bound for
    // that method call.
    await batchedTdb({ subRequests })
      .then((responseData) => {
        if (responseData.inferenceSubResponses.length !== subRequests.length) {
          throw new Error(
            "Expected exactly " +
              subRequests.length +
              " inferenceSubResponses, but got " +
              responseData.inferenceSubResponses.length
          );
        }
        for (let i = 0; i < responseData.inferenceSubResponses.length; i++) {
          setResponseFns[i](responseData.inferenceSubResponses[i]);
        }
        setActivationServerErrorMessage(null);
      })
      .catch((error) => this.handleInferenceError(error, setActivationServerErrorMessage));
  }