in tensorflow/tensorflow/core/grappler/optimizers/constant_folding.cc [2843:3016]
// Pushes a constant operand of a binary Add/Mul (or Sub/Div, handled by a
// local tree rotation) one level down into its non-constant Add/Sub/Mul/Div
// child, so that constants gather near the leaves where they can be folded.
//
// Returns true iff `node` and its non-constant child were rewritten in place
// (op types and/or inputs), in which case node_map_ has been updated to match
// the new edges. Returns false and leaves the graph untouched otherwise.
// `optimized_graph` is not read by this function.
bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
                                       NodeDef* node) {
  // Consider the transformation
  //
  //        +                +       = parent
  //       / \              / \
  //      C   +    -- >    X   +     = children
  //         / \              / \
  //        X   Y            C   Y   = leaves
  //
  // where C is constant, X is non-constant, Y may be constant or non-constant,
  // and '+' denotes an associative and commutative operator like addition or
  // multiplication. This optimization pushes constants down in the tree to
  // canonicalize it. Moreover, in cases where the child node has a second
  // constant input Y we will create a leaf node that can be folded, e.g.
  //
  //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
  //
  // We also handle the non-commutative cases of subtraction and division
  // by rotating the tree locally, e.g.
  //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
  //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
  //
  // Note: Don't touch BiasAdd since they can't handle vectors as their first
  // inputs.

  // Get parent op type.
  const bool is_add = IsAdd(*node);
  const bool is_mul = IsMul(*node);
  const bool is_sub = IsSub(*node);
  const bool is_div = IsDiv(*node);
  const bool is_symmetric = is_add || is_mul;
  if (!has_fetch_ || !(is_add || is_sub || is_mul || is_div) ||
      NumNonControlInputs(*node) != 2) {
    return false;
  }

  NodeDef* left_child = node_map_->GetNode(node->input(0));
  NodeDef* right_child = node_map_->GetNode(node->input(1));
  // Defensive: if the node map has no entry for either input, skip the
  // rewrite instead of dereferencing a null pointer below.
  if (left_child == nullptr || right_child == nullptr) {
    return false;
  }
  const bool left_child_is_constant = IsReallyConstant(*left_child);
  const bool right_child_is_constant = IsReallyConstant(*right_child);
  if (!left_child_is_constant && !right_child_is_constant) {
    return false;
  }
  // Don't move nodes across devices.
  if (node->device() != left_child->device() ||
      node->device() != right_child->device()) {
    return false;
  }
  NodeDef* op_child = left_child_is_constant ? right_child : left_child;
  NodeDef* const_child = left_child_is_constant ? left_child : right_child;

  // Don't rewrite the tree if it might create cycles: op_child is about to
  // consume C, so it must not already feed C.
  // TODO(rmlarsen): Add back handling of control dependency from op to C.
  const auto& child_output = node_map_->GetOutputs(op_child->name());
  if (child_output.find(const_child) != child_output.end()) {
    return false;
  }

  // Get child op type. Only Add/Sub under Add/Sub and Mul/Div under Mul/Div
  // can be reassociated.
  const bool is_child_add = IsAdd(*op_child);
  const bool is_child_mul = IsMul(*op_child);
  const bool is_child_sub = IsSub(*op_child);
  const bool is_child_div = IsDiv(*op_child);
  const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
  const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
  if (!is_add_sub && !is_mul_div) {
    return false;
  }
  const bool is_child_symmetric = is_child_add || is_child_mul;

  // Make sure that it is safe to change the value of the child node result:
  // it must not be preserved and must have no other data consumers.
  if (op_child->input_size() < 2 ||
      nodes_to_preserve_.find(op_child->name()) != nodes_to_preserve_.end() ||
      NumNonControlOutputs(*op_child, *node_map_) > 1) {
    return false;
  }

  if (!CheckAttrExists(*node, "T").ok()) return false;
  DataType dtype = node->attr().at("T").type();
  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
    // Don't apply reassociation to floating point types of low precision.
    // The danger of significant numerical changes is too high.
    return false;
  }
  // Do not rewrite integer expressions with subtraction or division: any
  // rotation involving a non-commutative op is only allowed for floating
  // point and complex types.
  if (!(is_symmetric && is_child_symmetric) &&
      !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
    return false;
  }

  // Identify the nodes to swap.
  NodeDef* left_leaf = node_map_->GetNode(op_child->input(0));
  NodeDef* right_leaf = node_map_->GetNode(op_child->input(1));
  // Defensive: same missing-entry guard as for the children above.
  if (left_leaf == nullptr || right_leaf == nullptr) {
    return false;
  }
  const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
  const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
  if (left_leaf_is_constant && right_leaf_is_constant) {
    // Child is already foldable, leave it alone.
    return false;
  }
  // Don't move nodes across devices.
  if (node->device() != left_leaf->device() ||
      node->device() != right_leaf->device()) {
    return false;
  }

  // Get the input names corresponding to X, Y, C and the child op.
  const string input_x =
      left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
  const string input_y =
      input_x == op_child->input(0) ? op_child->input(1) : op_child->input(0);
  const string input_c =
      left_child_is_constant ? node->input(0) : node->input(1);
  const string input_op =
      left_child_is_constant ? node->input(1) : node->input(0);
  VLOG(1) << "\n++++++++ Reordering node " << node->name() << ": " << node->op()
          << "(" << left_child->op() << ", " << right_child->op() << ")\n";

  // Now we have identified the nodes to swap (X and C). Update the node map:
  // `node` now reads X instead of C, and op_child reads C instead of X. When
  // X and Y are the same input string, op_child still consumes it as Y, so
  // that edge is kept.
  // NOTE(review): AddOutput/RemoveOutput receive the raw input strings
  // (possibly carrying a ":port" suffix). Confirm that NodeMap normalizes
  // them to node names; if it does not, these should go through NodeName().
  node_map_->UpdateInput(node->name(), input_c, input_x);
  node_map_->AddOutput(input_c, op_child->name());
  if (input_x != input_y) {
    node_map_->RemoveOutput(input_x, op_child->name());
  }

  if (is_symmetric && is_child_symmetric) {
    // Easy case (only commutative ops). We always write this as
    //
    //        +
    //       / \
    //      X   +
    //         / \
    //        C   Y
    //
    node->set_input(0, input_x);
    node->set_input(1, input_op);
    op_child->set_input(0, input_c);
    op_child->set_input(1, input_y);
  } else {
    // More complicated case: When there are non-commutative operations like
    // subtractions or divisions involved, we may have to rotate the tree
    // and/or change op types. There are 6 non-trivial cases depending on
    // the effective generalized "sign" of each of the three terms C, Y, and X.
    // Here are the final trees we want to generate for those 6 cases:
    //
    // (CYX signs):   ++-      +--      -+-     --+     +-+      -++
    //
    //                 -        -        -       -       +        +
    //                / \      / \      / \     / \     / \      / \
    //               +   X    -   X    -   X   X   +   X   -    X   -
    //              / \      / \      / \         / \     / \      / \
    //             C   Y    C   Y    Y   C       Y   C   C   Y    Y   C
    //
    // First, let's determine the effective sign of each term in the original
    // expression. A leaf is negated if it is the right operand of a
    // non-commutative child, flipped again if that child is itself the right
    // operand of a non-commutative parent.
    auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
      bool leaf_negated = !is_child_symmetric && is_right_leaf;
      bool child_negated = !is_symmetric && (op_child == right_child);
      return leaf_negated != child_negated;
    };
    const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
    const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
    bool neg_c = !is_symmetric && (const_child == right_child);
    // X is the right leaf exactly when the left leaf is the constant one.
    bool neg_x = is_leaf_negated(left_leaf_is_constant);
    bool neg_y = is_leaf_negated(!left_leaf_is_constant);
    // Rewrite the parent node.
    node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
    node->set_input(0, neg_x ? input_op : input_x);
    node->set_input(1, neg_x ? input_x : input_op);
    // Rewrite the child node.
    op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
    op_child->set_input(0, neg_c ? input_y : input_c);
    op_child->set_input(1, neg_c ? input_c : input_y);
  }
  return true;
}