Graph compose()

in bindings/python/src/gtn/functions/compose.cpp [377:522]


/**
 * Composes two graphs, matching the output labels of `first` against the
 * input labels of `second` via `matcher`.
 *
 * Nodes of the result correspond to pairs (node of `first`, node of
 * `second`) and are discovered breadth-first from the pairs of start nodes.
 * Start-node pairs are only added when `findReachable` marks them reachable;
 * the same `reachable` table is handed to the helpers that add the remaining
 * nodes. For every matched arc pair (i, j) the composed arc carries input
 * label `first.ilabel(i)`, output label `second.olabel(j)` and weight
 * `first.weight(i) + second.weight(j)`. Epsilon transitions are not taken
 * from direct matches; they are followed separately through
 * `addEpsilonReachableNodes` (see the case analysis in the body).
 *
 * @param first   left operand; its output labels are matched.
 * @param second  right operand; its input labels are matched.
 * @param matcher enumerates the matching arc pairs leaving a given pair of
 *                nodes (`match` / `hasNext` / `next`).
 * @return the composed graph. Its gradient function sends the gradient of
 *         each composed arc back to the input arc(s) it was built from.
 */
Graph compose(
    const Graph& first,
    const Graph& second,
    std::shared_ptr<ArcMatcher> matcher) {
  // Compute reachable nodes from any accept state in the new graph
  auto reachable = findReachable(first, second, matcher);

  // Compose the graphs
  Graph ngraph(nullptr, {first, second});
  // Flat representation of nodes in both graphs, indexed using toIndex. An
  // entry of -1 means that node pair has not (yet) been added to ngraph.
  std::vector<int> newNodes(first.numNodes() * second.numNodes(), -1);
  std::queue<std::pair<int, int>> toExplore;
  // Compile starting nodes that are reachable. If any pairs of reachable start
  // nodes in the input graph are also both accept nodes, make these accept
  // nodes in the composed graph.
  for (auto s1 : first.start()) {
    for (auto s2 : second.start()) {
      auto idx = toIndex(s1, s2, first);
      if (reachable[idx]) {
        newNodes[idx] =
            ngraph.addNode(true, first.isAccept(s1) && second.isAccept(s2));
        toExplore.emplace(s1, s2);
      }
    }
  }

  // The index of a particular pair entry in gradInfo corresponds to an arc in
  // the composed graph - at gradient computation time, this facilitates
  // efficiently mapping an arc in the composed graph to the corresponding arcs
  // in the first and second graphs. Every arc added to ngraph below must
  // therefore push exactly one entry, in the same order.
  std::vector<std::pair<int, int>> gradInfo;
  // Explore the graph starting from the collection of start nodes
  while (!toExplore.empty()) {
    // Copy before pop(): front() returns a reference into the queue.
    auto curr = toExplore.front();
    toExplore.pop();
    // A node in the composed graph
    auto currNode = newNodes[toIndex(curr.first, curr.second, first)];
    bool epsilon_matched = false;
    matcher->match(curr.first, curr.second);
    // Each pair of nodes in the initial graph may have multiple outgoing arcs
    // that should be combined in the composed graph
    while (matcher->hasNext()) {
      // The matcher invariant remains: arc i's olabel (from the first graph) is
      // arc j's ilabel (from the second graph)
      int i = 0;
      int j = 0;
      std::tie(i, j) = matcher->next();

      // Ignore direct epsilon matches; epsilons are followed below instead.
      if (first.olabel(i) == epsilon) {
        epsilon_matched = true;
        continue;
      }

      bool isReachable = addReachableNodeAndArc(
          first,
          second,
          currNode,
          std::make_pair(first.dstNode(i), second.dstNode(j)),
          first.weight(i) + second.weight(j),
          first.ilabel(i),
          second.olabel(j),
          reachable,
          toExplore,
          newNodes,
          ngraph);

      if (isReachable) {
        // Arcs remember where they came from for easy gradient computation.
        gradInfo.emplace_back(i, j);
      }
    }

    // The logic of when to check for epsilon transitions is as follows:
    // Case 1: No epsilon match.
    //   If there was no epsilon match then at most one of the two graphs has
    //   an epsilon transition and we can check both safely.
    //
    // Case 2: Epsilon match.
    //   If there was an epsilon match then we have to be careful to avoid
    //   redundant paths.
    //   1. Follow the epsilon transition out of the non accepting node.
    //   2. If both nodes are accepting follow both transitions.
    //   3. If neither node is accepting (arbitrarily) follow only the first
    //   node's transition.
    // Check for output epsilons in the first graph
    if (!epsilon_matched || second.isAccept(curr.second) ||
        !first.isAccept(curr.first)) {
      addEpsilonReachableNodes(
          false,
          first,
          second,
          currNode, // in the composed graph
          curr, // in the input graphs
          reachable,
          toExplore,
          newNodes,
          ngraph,
          gradInfo);
    }

    // Check for input epsilons in the second graph
    if (!epsilon_matched || first.isAccept(curr.first)) {
      addEpsilonReachableNodes(
          true,
          first,
          second,
          currNode, // in the composed graph
          curr, // in the input graphs
          reachable,
          toExplore,
          newNodes,
          ngraph,
          gradInfo);
    }
  }

  /*
   * Here we assume deltas is the output (e.g. ngraph) and we know where
   * each arc came from. This makes it possible to disambiguate two arcs in the
   * composed graph with the same label and the same src and destination nodes.
   * A negative index in a gradInfo entry means the composed arc has no parent
   * arc in that input graph, so no gradient is routed there.
   */
  auto gradFunc = [gradInfo = std::move(gradInfo)](
                      std::vector<Graph>& inputs, Graph deltas) {
    // In this case the arc's parents are always from the
    // first and second input graphs respectively.
    bool calcGrad1 = inputs[0].calcGrad();
    bool calcGrad2 = inputs[1].calcGrad();
    auto grad1 = calcGrad1 ? std::vector<float>(inputs[0].numArcs(), 0.0)
                           : std::vector<float>{};
    auto grad2 = calcGrad2 ? std::vector<float>(inputs[1].numArcs(), 0.0)
                           : std::vector<float>{};
    // size_t index: avoids a signed/unsigned comparison against size().
    for (size_t arcIdx = 0; arcIdx < gradInfo.size(); ++arcIdx) {
      auto arcGrad = deltas.weight(arcIdx);
      const auto& arcs = gradInfo[arcIdx];
      if (calcGrad1 && arcs.first >= 0) {
        grad1[arcs.first] += arcGrad;
      }
      if (calcGrad2 && arcs.second >= 0) {
        grad2[arcs.second] += arcGrad;
      }
    }
    inputs[0].addGrad(std::move(grad1));
    inputs[1].addGrad(std::move(grad2));
  };

  ngraph.setGradFunc(std::move(gradFunc));
  return ngraph;
}