export function useExplanationFetcher()

in neuron_viewer/src/TransformerDebugger/requests/explanationFetcher.ts [18:111]


export function useExplanationFetcher(maxRequestsInProgress: number = 2): {
  // Explanations that have been retrieved so far.
  explanationMap: ExplanationMap;
  // Function to set which nodes we're requesting explanations for.
  setNodesRequestingExplanation: React.Dispatch<React.SetStateAction<Node[]>>;
} {
  // Custom hook that keeps a map of explanations (keyed by nodeToStringKey) and takes requests to
  // retrieve more explanations. At most maxRequestsInProgress entries are "in_progress" at once;
  // the remaining requested nodes are started as slots free up.
  //
  // Entry lifecycle:
  //   "in_progress", scoredExplanations: []      - request started
  //   "in_progress", scoredExplanations: [...]   - explanations received, first one being scored
  //   "success"                                  - first explanation scored, or explain returned
  //                                                no explanations (scoredExplanations: [])
  //   "error", scoredExplanations: null          - explain or scoreExplanation threw
  const [explanationMap, setExplanationMap] = useState<ExplanationMap>(() => new Map());
  const [nodesRequestingExplanation, setNodesRequestingExplanation] = useState<Node[]>([]);

  // Fetches explanations for one node, then scores the first explanation. Every outcome moves the
  // node's entry out of "in_progress", so its concurrency slot is always released.
  const explainAndScoreAsync = useCallback(async (node: Node) => {
    const key = nodeToStringKey(node);
    try {
      const explainResult = await explain(node);
      const firstExplanation = explainResult.explanations[0];
      setExplanationMap((prevExplanationMap: ExplanationMap) => {
        const newExplanationsMap = new Map(prevExplanationMap);
        newExplanationsMap.set(key, {
          // With no explanations there is nothing to score, so the request is finished. Leaving it
          // "in_progress" would permanently occupy one of the maxRequestsInProgress slots.
          state: firstExplanation === undefined ? "success" : "in_progress",
          scoredExplanations: explainResult.explanations.map((explanation) => ({
            datasetName: explainResult.dataset,
            explanation,
          })),
        });
        return newExplanationsMap;
      });
      if (firstExplanation === undefined) {
        return;
      }
      const scoreResult = await scoreExplanation(node, firstExplanation);
      setExplanationMap((prevExplanationMap: ExplanationMap) => {
        const prevEntry = prevExplanationMap.get(key);
        const [first, ...rest] = prevEntry?.scoredExplanations ?? [];
        if (first === undefined) {
          // Nothing to attach the score to. Return the same map so React skips the update.
          return prevExplanationMap;
        }
        const newExplanationsMap = new Map(prevExplanationMap);
        newExplanationsMap.set(key, {
          state: "success",
          // Copy the scored element rather than mutating it: it is shared with the previous state.
          scoredExplanations: [{ ...first, score: scoreResult.score }, ...rest],
        });
        return newExplanationsMap;
      });
    } catch (error) {
      // This catch covers errors from both explain and scoreExplanation.
      console.log(error);
      setExplanationMap((prevExplanationMap: ExplanationMap) => {
        const newExplanationsMap = new Map(prevExplanationMap);
        newExplanationsMap.set(key, {
          state: "error",
          scoredExplanations: null,
        });
        return newExplanationsMap;
      });
    }
  }, []);

  useEffect(() => {
    // Count entries that are currently in_progress.
    let inProgressCount = 0;
    explanationMap.forEach((value) => {
      if (value.state === "in_progress") {
        inProgressCount++;
      }
    });
    const requestsToStart = maxRequestsInProgress - inProgressCount;
    if (requestsToStart <= 0) {
      return;
    }

    // Pick up to requestsToStart requested nodes that have no entry yet, in request order.
    // Dedupe by string key rather than object identity: distinct Node objects describing the same
    // node would otherwise each start a request, since explanationMap is not updated mid-loop.
    const seenKeys = new Set<string>();
    const nodesToRequest: Node[] = [];
    for (const node of nodesRequestingExplanation) {
      if (nodesToRequest.length >= requestsToStart) {
        break;
      }
      const key = nodeToStringKey(node);
      if (seenKeys.has(key) || explanationMap.has(key)) {
        continue;
      }
      seenKeys.add(key);
      nodesToRequest.push(node);
    }
    if (nodesToRequest.length === 0) {
      return;
    }

    // Mark all selected nodes in_progress in one state update, then start their requests.
    setExplanationMap((prevExplanationMap: ExplanationMap) => {
      const newExplanations = new Map(prevExplanationMap);
      for (const node of nodesToRequest) {
        newExplanations.set(nodeToStringKey(node), {
          state: "in_progress",
          scoredExplanations: [],
        });
      }
      return newExplanations;
    });
    for (const node of nodesToRequest) {
      console.log("starting explanation request for", node);
      // explainAndScoreAsync catches its own errors, so the promise is intentionally not awaited.
      void explainAndScoreAsync(node);
    }
  }, [explanationMap, explainAndScoreAsync, nodesRequestingExplanation, maxRequestsInProgress]);

  return {
    explanationMap,
    setNodesRequestingExplanation,
  };
}