Status ConstantFolding::SimplifyArithmeticOperations()

in tensorflow/tensorflow/core/grappler/optimizers/constant_folding.cc [2669:2793]
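
This routine pattern-matches a binary arithmetic node against its identity and absorbing elements: multiplication by one, addition or subtraction of zero, division by one, subtraction from zero, one divided by a tensor, and multiplication (or matrix multiplication) by zero. When shape information permits, the node is replaced by a cheaper Snapshot, BroadcastTo, Neg, Reciprocal, Identity, or Const node. Logical AND and OR are folded through the same paths, with true acting as one and false as zero.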


Status ConstantFolding::SimplifyArithmeticOperations(
    const GraphProperties& properties, bool use_shape_info,
    GraphDef* optimized_graph, NodeDef* node) {
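  // Classify the op once up front. Logical AND and OR reuse the
  // multiplication and addition rewrites below, with true and false playing
  // the roles of one and zero.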
  const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
  const bool is_matmul = IsAnyMatMul(*node);
  const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
  const bool is_sub = IsSub(*node);
  const bool is_any_div = IsAnyDiv(*node);
  // Simplify arithmetic operations with ones or zeros.
  if (use_shape_info &&
      (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
      properties.HasInputProperties(node->name()) &&
      properties.HasOutputProperties(node->name())) {
    const NodeDef* x = node_map_->GetNode(node->input(0));
    const NodeDef* y = node_map_->GetNode(node->input(1));
    if (x == nullptr || y == nullptr) {
      return errors::InvalidArgument("Invalid inputs to node: ",
                                     node->DebugString());
    }
    const TensorShapeProto& output_shape =
        properties.GetOutputProperties(node->name())[0].shape();

    // Simplify element-wise multiplication by ones or addition/subtraction
    // of zeros.
    const TensorShapeProto& y_shape =
        properties.GetInputProperties(node->name())[1].shape();
    const TensorShapeProto& x_shape =
        properties.GetInputProperties(node->name())[0].shape();
    const bool y_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, y_shape);
    const bool x_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, x_shape);
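
    // An operand can only be forwarded as the result if it already has the
    // op's output shape; when broadcasting changed the shape, an explicit
    // BroadcastTo is required instead.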

    const bool x_is_zero = IsZeros(*x);
    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
    if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
      // 1 * y = y or 0 + y = y.
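      // The result equals y: forward y if it already has the output shape,
      // otherwise broadcast it to the output shape (which x determined).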
      if (y_matches_output_shape) {
        ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
      } else if (x_matches_output_shape) {
        ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
                                              optimized_graph);
      }
      return Status::OK();
    }

    if (y_matches_output_shape && (is_sub && x_is_zero)) {
      // Replace 0 - y with Neg(y).
      ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
      return Status::OK();
    }

    // Replace 1 / y with Reciprocal op.
    if (y_matches_output_shape && is_any_div && x_is_one) {
      TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
      DataType type = node->attr().at("T").type();
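      // Rewrite only floating-point and complex types, for which x / y with
      // x == 1 is exactly Reciprocal(y).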
      if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
        ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
        return Status::OK();
      }
    }

    const bool y_is_zero = IsZeros(*y);
    const bool y_is_one = y_is_zero ? false : IsOnes(*y);
    if (((is_mul || is_any_div) && y_is_one) ||
        ((is_add || is_sub) && y_is_zero)) {
      // x * 1 = x or x / 1 = x or x +/- 0 = x
      if (x_matches_output_shape) {
        ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
      } else if (y_matches_output_shape) {
        ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
                                              optimized_graph);
      }
      return Status::OK();
    }

    // x OR true = true OR y = true.
    const PartialTensorShape shp(output_shape);
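    // Materializing the result as a constant requires a fully known output
    // shape.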
    if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
          1, properties, output_shape, node, optimized_graph));
      return Status::OK();
    }

    // Simplify multiplication and matmul by zeros.
    // Also optimize zeros divided by a tensor, but only in aggressive mode,
    // since the rewrite can hide a division by zero that would otherwise
    // occur at runtime.
    const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
    bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
    if ((x_is_zero || y_is_zero) &&
        (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
      if (shp.IsFullyDefined()) {
        bool is_quantized = IsQuantizedMatMul(*node);
        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
            0, properties, output_shape, node, optimized_graph));
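        // QuantizedMatMul has extra min/max outputs, which must also be
        // replaced with constants once the product itself is folded.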
        if (is_quantized && graph_modified_) {
          TF_RETURN_IF_ERROR(
              AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
        }
        return Status::OK();
      }
      // Even if an input shape is only partially known, we may know that it
      // matches the output shape and thus forward or broadcast the
      // corresponding zero input.
      if ((is_mul || is_any_div) && x_is_zero) {
        if (x_matches_output_shape) {
          ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        } else if (y_matches_output_shape) {
          ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
                                                optimized_graph);
        }
        return Status::OK();
      } else if (is_mul && y_is_zero) {
        if (y_matches_output_shape) {
          ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
        } else if (x_matches_output_shape) {
          ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
                                                optimized_graph);
        }
        return Status::OK();
      }
    }
  }
  return Status::OK();
}
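
For context, here is a minimal sketch of how this rewrite is exercised end to end, modeled on the patterns used in constant_folding_test.cc; the graph and setup below are illustrative assumptions, not part of the listed function. Multiplying a placeholder by a matching-shape tensor of ones triggers the "x * 1 = x" branch above, which replaces the Mul with a Snapshot of x.

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"

namespace tensorflow {
namespace grappler {

void SimplifyMulByOnesExample() {
  // Build Mul(x, ones) where ones has the same shape as x and as the output.
  Scope s = Scope::NewRootScope();
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape({2, 2}));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {2, 2});
  Output mul = ops::Mul(s.WithOpName("mul"), x, ones);

  GrapplerItem item;
  item.fetch = {"mul"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // Run constant folding; shape inference lets the pass prove that x matches
  // the output shape, so the Mul node is replaced with a Snapshot of x.
  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));
}

}  // namespace grappler
}  // namespace tensorflow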