Value MaterializeZeta()

in lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc [802:943]


// Emits MHLO ops that evaluate the Hurwitz zeta function
//   zeta(x, q) = sum_{k >= 0} (q + k)^(-x)
// and returns the resulting value.
//
// `args` must hold exactly two values: args[0] is `x` and args[1] is `q`.
// All ops are created through `rewriter` at location `loc`.
//
// The evaluation follows the Cephes scheme: a fixed number of leading terms is
// summed directly and the remaining tail is approximated with the
// Euler-Maclaurin formula. Afterwards a chain of selects patches in NaN/+inf
// for arguments outside the domain and at the poles.
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                      ValueRange args) {
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  // (2k)! / B_{2k} for k = 12 down to 1 (B_{2k} being Bernoulli numbers), i.e.
  // the Cephes table in reverse order so that Horner's rule can consume it
  // front to back.
  static constexpr std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12 term expansion for the Euler-Maclaurin formula.
  Value a = q;
  // NOTE(review): `zero` and `one` are built like `q` but are later also
  // combined with values derived from `x`; this presumably relies on `x` and
  // `q` having the same type.
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  // initial_sum = q^(-x) + (q + 1)^(-x) + ... + (q + 9)^(-x).
  Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mhlo::AddOp>(loc, a, one);
    Value term = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
    initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, term);
  }

  // From here on a = q + 10 and neg_power = a^(-x); this term is not part of
  // initial_sum, it only enters through the Euler-Maclaurin correction below.
  a = rewriter.create<mhlo::AddOp>(loc, a, one);
  Value neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
  Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
  // Integral term of Euler-Maclaurin: a^(1 - x) / (x - 1).
  Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a);
  Value neg_power_mul_a_div_x_minus_one =
      rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one);
  Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
                                         neg_power_mul_a_div_x_minus_one);
  Value a_inverse_square = rewriter.create<mhlo::DivOp>(
      loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));

  Value horner_sum = zero;
  // Use Horner's rule for this.
  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
  // Using Horner's rule allows to avoid some NaN's and Infs from happening,
  // resulting in more numerically stable code.
  //
  // The k-th Euler-Maclaurin term is
  //   B_{2k} / (2k)! * x (x + 1) ... (x + 2k - 2) * a^(-x - 2k + 1),
  // i.e. it contains a *rising* factorial of x. Consecutive terms therefore
  // differ by a factor (x + 2j - 1) (x + 2j) / a^2, so the factors must be
  // formed with additions (Cephes likewise accumulates `a *= x + k`).
  for (int i = 0; i < 11; ++i) {
    Value factor_lhs = rewriter.create<mhlo::AddOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factor_rhs = rewriter.create<mhlo::AddOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
    Value factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs);
    horner_sum = rewriter.create<mhlo::MulOp>(
        loc, factor,
        rewriter.create<mhlo::MulOp>(
            loc, a_inverse_square,
            rewriter.create<mhlo::AddOp>(
                loc, horner_sum,
                chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
  }
  // s += a^(-x) * (1/2 + x / a * (1 / 12 + horner_sum)).
  Value zero_point_five_like_neg_power =
      chlo::getConstantLike(rewriter, loc, .5, neg_power);
  Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a);
  s = rewriter.create<mhlo::AddOp>(
      loc, s,
      rewriter.create<mhlo::MulOp>(
          loc, neg_power,
          rewriter.create<mhlo::AddOp>(
              loc, zero_point_five_like_neg_power,
              rewriter.create<mhlo::MulOp>(
                  loc, x_div_a,
                  rewriter.create<mhlo::AddOp>(
                      loc,
                      chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[11],
                                            a),
                      horner_sum)))));

  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power);
  Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum);
  Value output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(
          loc, abs_neg_power,
          rewriter.create<mhlo::MulOp>(
              loc, abs_initial_sum,
              chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          kLT),
      initial_sum, s);

  // The selects below are applied in sequence, so a later one overrides an
  // earlier one wherever both conditions hold.

  // Function is not defined for x < 1.
  Value nan = chlo::getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), nan,
      output);

  // For q <= 0, x must be an integer.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero, kLE);
  Value x_not_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE);
  Value x_domain_error =
      rewriter.create<mhlo::AndOp>(loc, q_le_zero, x_not_int);
  output = rewriter.create<mhlo::SelectOp>(loc, x_domain_error, nan, output);

  // For all integer q <= 0, zeta has a pole. The limit is only defined as
  // +inf if x is an even integer.
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value inf = chlo::getConstantLike(rewriter, loc,
                                    std::numeric_limits<double>::infinity(), x);
  Value q_is_int = rewriter.create<mhlo::CompareOp>(
      loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ);
  Value at_pole = rewriter.create<mhlo::AndOp>(loc, q_le_zero, q_is_int);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value x_is_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
  Value x_is_even = rewriter.create<mhlo::CompareOp>(
      loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero, kEQ);
  Value x_is_even_int = rewriter.create<mhlo::AndOp>(loc, x_is_int, x_is_even);
  output = rewriter.create<mhlo::SelectOp>(
      loc, at_pole,
      rewriter.create<mhlo::SelectOp>(loc, x_is_even_int, inf, nan), output);

  // For x = 1, this is the harmonic series and diverges.
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one, kEQ), inf, output);

  return output;
}