in bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/bindings/python/src/gtn/functions/compose.cpp [377:522]
/**
 * Compose `first` with `second`, pairing arcs through `matcher`.
 *
 * The result is a new graph whose inputs are `{first, second}`. Each node of
 * the result stands for a pair (node of `first`, node of `second`); only pairs
 * flagged in the `reachable` table returned by `findReachable` are created.
 * For every non-epsilon arc pair (i, j) yielded by the matcher, an arc is
 * requested with input label `first.ilabel(i)`, output label
 * `second.olabel(j)` and weight `first.weight(i) + second.weight(j)`.
 *
 * A gradient function is attached to the result; it scatters the gradient of
 * each composed arc back onto the arcs of `first` and `second` that produced
 * it.
 *
 * @param first The graph whose arc output labels are matched.
 * @param second The graph whose arc input labels are matched.
 * @param matcher Enumerates matching (arc of `first`, arc of `second`) pairs
 *   for a given pair of nodes. Must be non-null; it is dereferenced here.
 * @return The composed graph.
 */
Graph compose(
    const Graph& first,
    const Graph& second,
    std::shared_ptr<ArcMatcher> matcher) {
  // Compute reachable nodes from any accept state in the new graph
  auto reachable = findReachable(first, second, matcher);

  // Compose the graphs
  Graph ngraph(nullptr, {first, second});
  // Flat representation of nodes in both graphs, indexed using toIndex.
  // An entry of -1 means the pair has no node in the composed graph (yet).
  std::vector<int> newNodes(first.numNodes() * second.numNodes(), -1);
  std::queue<std::pair<int, int>> toExplore;

  // Compile starting nodes that are reachable. If any pairs of reachable start
  // nodes in the input graph are also both accept nodes, make these accept
  // nodes in the composed graph.
  for (auto s1 : first.start()) {
    for (auto s2 : second.start()) {
      auto idx = toIndex(s1, s2, first);
      if (reachable[idx]) {
        newNodes[idx] =
            ngraph.addNode(true, first.isAccept(s1) && second.isAccept(s2));
        toExplore.emplace(s1, s2);
      }
    }
  }

  // The index of a particular pair entry in gradInfo corresponds to an arc in
  // the composed graph - at gradient computation time, this facilitates
  // efficiently mapping an arc in the composed graph to the corresponding arcs
  // in the first and second graphs
  std::vector<std::pair<int, int>> gradInfo;

  // Explore the graph starting from the collection of start nodes
  while (!toExplore.empty()) {
    auto curr = toExplore.front();
    toExplore.pop();
    // A node in the composed graph
    auto currNode = newNodes[toIndex(curr.first, curr.second, first)];

    bool epsilonMatched = false;
    matcher->match(curr.first, curr.second);
    // Each pair of nodes in the initial graph may have multiple outgoing arcs
    // that should be combined in the composed graph
    while (matcher->hasNext()) {
      // The matcher invariant remains: arc i's olabel (from the first graph) is
      // arc j's ilabel (from the second graph)
      int i = -1;
      int j = -1;
      std::tie(i, j) = matcher->next();
      // Ignore direct epsilon matches; only remember that one occurred so the
      // epsilon handling below can avoid creating redundant paths.
      if (first.olabel(i) == epsilon) {
        epsilonMatched = true;
        continue;
      }
      bool isReachable = addReachableNodeAndArc(
          first,
          second,
          currNode,
          std::make_pair(first.dstNode(i), second.dstNode(j)),
          first.weight(i) + second.weight(j),
          first.ilabel(i),
          second.olabel(j),
          reachable,
          toExplore,
          newNodes,
          ngraph);
      if (isReachable) {
        // Arcs remember where they came from for easy gradient computation.
        // NOTE(review): this keeps gradInfo aligned with the composed graph's
        // arcs only if the helper adds exactly one arc when it returns true -
        // confirm against addReachableNodeAndArc.
        gradInfo.emplace_back(i, j);
      }
    }

    // The logic of when to check for epsilon transitions is as follows:
    // Case 1: No epsilon match.
    //   If there was no epsilon match then at most one of the two graphs has
    //   an epsilon transition and we can check both safely.
    //
    // Case 2: Epsilon match.
    //   If there was an epsilon match then we have to be careful to avoid
    //   redundant paths.
    //     1. Follow the epsilon transition out of the non accepting node.
    //     2. If both nodes are accepting follow both transitions.
    //     3. If neither node is accepting (arbitrarily) follow only the first
    //        node's transition.
    const bool firstAccepts = first.isAccept(curr.first);
    const bool secondAccepts = second.isAccept(curr.second);
    // Case 1, or Case 2 with: first non-accepting (rules 1 and 3) or both
    // accepting (rule 2).
    const bool followFirstEpsilons =
        !epsilonMatched || secondAccepts || !firstAccepts;
    // Case 1, or Case 2 with: first accepting, i.e. second is the non-accepting
    // node (rule 1) or both are accepting (rule 2).
    const bool followSecondEpsilons = !epsilonMatched || firstAccepts;

    // Epsilons handled on the first graph's side (flag is false).
    if (followFirstEpsilons) {
      addEpsilonReachableNodes(
          false,
          first,
          second,
          currNode, // in the composed graph
          curr, // in the input graphs
          reachable,
          toExplore,
          newNodes,
          ngraph,
          gradInfo);
    }
    // Check for input epsilons in the second graph
    if (followSecondEpsilons) {
      addEpsilonReachableNodes(
          true,
          first,
          second,
          currNode, // in the composed graph
          curr, // in the input graphs
          reachable,
          toExplore,
          newNodes,
          ngraph,
          gradInfo);
    }
  }

  /*
   * Here we assume deltas is the output (e.g. ngraph) and we know where
   * each arc came from. This makes it possible to disambiguate two arcs in the
   * composed graph with the same label and the same src and destination nodes.
   */
  auto gradFunc = [gradInfo = std::move(gradInfo)](
                      std::vector<Graph>& inputs, Graph deltas) {
    // In this case the arc's parents are always from the
    // first and second input graphs respectively.
    const bool calcGrad1 = inputs[0].calcGrad();
    const bool calcGrad2 = inputs[1].calcGrad();
    // Gradient buffers are only sized when the corresponding input asks for a
    // gradient; otherwise an empty vector is handed to addGrad.
    auto grad1 = calcGrad1 ? std::vector<float>(inputs[0].numArcs(), 0.0f)
                           : std::vector<float>{};
    auto grad2 = calcGrad2 ? std::vector<float>(inputs[1].numArcs(), 0.0f)
                           : std::vector<float>{};
    for (size_t i = 0; i < gradInfo.size(); ++i) {
      const auto arcGrad = deltas.weight(i);
      const auto& arcs = gradInfo[i];
      // A negative index is skipped: presumably it marks an arc that has no
      // counterpart in that input graph (e.g. an epsilon-only transition) -
      // confirm against addEpsilonReachableNodes.
      if (calcGrad1 && arcs.first >= 0) {
        grad1[arcs.first] += arcGrad;
      }
      if (calcGrad2 && arcs.second >= 0) {
        grad2[arcs.second] += arcGrad;
      }
    }
    inputs[0].addGrad(std::move(grad1));
    inputs[1].addGrad(std::move(grad2));
  };
  ngraph.setGradFunc(std::move(gradFunc));
  return ngraph;
}