bool ConstantFolding::ConstantPushDown()

in tensorflow/tensorflow/core/grappler/optimizers/constant_folding.cc [2843:3016]


// Pushes a constant operand of an Add/Sub/Mul/Div node down into its
// non-constant child of a compatible op family (Add/Sub or Mul/Div), so that
// constants end up next to each other and can be folded later.
//
// Args:
//   optimized_graph: not referenced by this function.
//   node: candidate parent node. On success, `node` and its non-constant child
//     are rewritten in place (inputs and possibly op types are changed).
// Returns:
//   true if the subtree rooted at `node` was rewritten (node_map_ is updated to
//   match), false if the pattern does not apply or the rewrite is unsafe.
bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
                                       NodeDef* node) {
  // Consider the transformation
  //
  //                      +                +       = parent
  //                     / \              / \
  //                    C   +    -- >    X   +     = children
  //                       / \              / \
  //                      X   Y            C   Y   = leaves
  //
  // where C is constant, X is non-constant, Y may be constant or non-constant,
  // and '+' denotes an associative and commutative operator like addition or
  // multiplication. This optimization pushes constants down in the tree to
  // canonicalize it. Moreover, in cases where the child node has a second
  // constant input Y we will create a leaf node that can be folded, e.g.
  //
  //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
  //
  // We also handle the non-commutative cases of subtraction and division
  // by rotating the tree locally, e.g.
  //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
  //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
  //
  // Note: Don't touch BiasAdd since they can't handle vectors as their first
  // inputs.

  // Get parent op type.
  const bool is_add = IsAdd(*node);
  const bool is_mul = IsMul(*node);
  const bool is_sub = IsSub(*node);
  const bool is_div = IsDiv(*node);
  const bool is_symmetric = is_add || is_mul;
  if (!has_fetch_ || !(is_add || is_sub || is_mul || is_div) ||
      NumNonControlInputs(*node) != 2) {
    return false;
  }

  NodeDef* left_child = node_map_->GetNode(node->input(0));
  NodeDef* right_child = node_map_->GetNode(node->input(1));
  // Both children are dereferenced below; skip the rewrite if either input
  // cannot be resolved in the node map.
  if (left_child == nullptr || right_child == nullptr) {
    return false;
  }

  const bool left_child_is_constant = IsReallyConstant(*left_child);
  const bool right_child_is_constant = IsReallyConstant(*right_child);
  if (!left_child_is_constant && !right_child_is_constant) {
    return false;
  }
  // Don't move nodes across devices.
  if (node->device() != left_child->device() ||
      node->device() != right_child->device()) {
    return false;
  }
  NodeDef* op_child = left_child_is_constant ? right_child : left_child;
  NodeDef* const_child = left_child_is_constant ? left_child : right_child;
  // Don't rewrite the tree if it might create cycles: C must not itself be a
  // consumer of the op child.
  // TODO(rmlarsen): Add back handling of control dependency from op to C.
  const auto& child_output = node_map_->GetOutputs(op_child->name());
  if (child_output.find(const_child) != child_output.end()) {
    return false;
  }
  // Get child op type. Only Add/Sub under Add/Sub, or Mul/Div under Mul/Div,
  // can be reassociated.
  const bool is_child_add = IsAdd(*op_child);
  const bool is_child_mul = IsMul(*op_child);
  const bool is_child_sub = IsSub(*op_child);
  const bool is_child_div = IsDiv(*op_child);
  const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
  const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
  if (!is_add_sub && !is_mul_div) {
    return false;
  }

  const bool is_child_symmetric = is_child_add || is_child_mul;
  // Make sure that it is safe to change the value of the child node result:
  // it must not be a preserved node and must have no other data consumers.
  if (op_child->input_size() < 2 ||
      nodes_to_preserve_.find(op_child->name()) != nodes_to_preserve_.end() ||
      NumNonControlOutputs(*op_child, *node_map_) > 1) {
    return false;
  }
  if (!CheckAttrExists(*node, "T").ok()) return false;
  const DataType dtype = node->attr().at("T").type();
  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
    // Don't apply reassociation to floating point types of low precision.
    // The danger of significant numerical changes is too high.
    return false;
  }
  // Do not rewrite integer expressions with subtraction or division: for
  // types that are neither floating point nor complex, only trees made purely
  // of Add or purely of Mul are reassociated.
  if (!(is_symmetric && is_child_symmetric) &&
      !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
    return false;
  }

  // Identify the nodes to swap.
  NodeDef* left_leaf = node_map_->GetNode(op_child->input(0));
  NodeDef* right_leaf = node_map_->GetNode(op_child->input(1));
  // Both leaves are dereferenced below; skip the rewrite if either input
  // cannot be resolved in the node map.
  if (left_leaf == nullptr || right_leaf == nullptr) {
    return false;
  }
  const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
  const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
  if (left_leaf_is_constant && right_leaf_is_constant) {
    // Child is already foldable, leave it alone.
    return false;
  }
  // Don't move nodes across devices.
  if (node->device() != left_leaf->device() ||
      node->device() != right_leaf->device()) {
    return false;
  }
  // Get the node names corresponding to X, Y, and C. X is a non-constant leaf
  // (the left one unless the left leaf is constant), Y is the other leaf.
  const string input_x =
      left_leaf_is_constant ? op_child->input(1) : op_child->input(0);
  const string input_y =
      input_x == op_child->input(0) ? op_child->input(1) : op_child->input(0);
  const string input_c =
      left_child_is_constant ? node->input(0) : node->input(1);
  const string input_op =
      left_child_is_constant ? node->input(1) : node->input(0);

  VLOG(1) << "\n++++++++ Reordering node " << node->name() << ": " << node->op()
          << "(" << left_child->op() << ", " << right_child->op() << ")\n";

  // Now we have identified the nodes to swap (input_x and input_c). Update the
  // node map: the parent now consumes X instead of C, the child now consumes C,
  // and the child stops consuming X unless X is also its other input Y.
  node_map_->UpdateInput(node->name(), input_c, input_x);
  node_map_->AddOutput(input_c, op_child->name());
  if (input_x != input_y) {
    node_map_->RemoveOutput(input_x, op_child->name());
  }

  if (is_symmetric && is_child_symmetric) {
    // Easy case (only commutative ops). We always write this as one of
    //   +
    //  / \
    // X   +
    //    / \
    //   C   Y
    node->set_input(0, input_x);
    node->set_input(1, input_op);
    op_child->set_input(0, input_c);
    op_child->set_input(1, input_y);
  } else {
    // More complicated case: When there are non-commutative operations like
    // subtractions or divisions involved, we may have to rotate the tree
    // and/or change op types. There are 6 non-trivial cases depending on
    // the effective generalized "sign" of each of the three terms C, Y, and X.
    // Here are the final trees we want to generate for those 6 cases:
    //
    // (CYX signs):   ++-      +--      -+-    --+     +-+      -++
    //                 -        -        -      -       +        +
    //                / \      / \      / \    / \     / \      / \
    //               +   X    -   X    -   X  X   +   X   -    X   -
    //              / \      / \      / \        / \     / \      / \
    //             C   Y    C   Y    Y   C      Y   C   C   Y    Y   C
    //

    // First, let's determine the effective sign of each term in the original
    // expression. A leaf is negated if it is the right operand of a
    // non-commutative child, toggled again if the child itself is the right
    // operand of a non-commutative parent.
    auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
      bool leaf_negated = !is_child_symmetric && is_right_leaf;
      bool child_negated = !is_symmetric && (op_child == right_child);
      return leaf_negated != child_negated;
    };
    const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
    const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
    bool neg_c = !is_symmetric && (const_child == right_child);
    // X is the right leaf exactly when the left leaf is the constant one.
    bool neg_x = is_leaf_negated(left_leaf_is_constant);
    bool neg_y = is_leaf_negated(!left_leaf_is_constant);
    // Rewrite the parent node.
    node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
    node->set_input(0, neg_x ? input_op : input_x);
    node->set_input(1, neg_x ? input_x : input_op);
    // Rewrite the child node.
    op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
    op_child->set_input(0, neg_c ? input_y : input_c);
    op_child->set_input(1, neg_c ? input_c : input_y);
  }
  return true;
}