in lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc [802:943]
// Emits MHLO ops that evaluate the two-argument zeta function
//   zeta(x, q) = sum_{k >= 0} (q + k)^(-x)
// and returns the resulting value.
//
// The emitted computation follows the Cephes approach: a fixed-length direct
// sum of the first terms, followed by an Euler-Maclaurin correction, followed
// by a chain of selects that overrides the result on the parts of the domain
// where the function is undefined or diverges.
//
// `rewriter` creates the ops, `loc` is attached to every created op, and
// `args` must hold exactly two values: x (the exponent) followed by q.
Value MaterializeZeta(ConversionPatternRewriter &rewriter, Location loc,
                      ValueRange args) {
  assert(args.size() == 2);
  Value x = args[0];
  Value q = args[1];
  // Euler-Maclaurin expansion coefficients (2k)! / B_2k, where B_2k are the
  // Bernoulli numbers, stored from k = 12 (index 0) down to k = 1 (index 11).
  static const std::array<double, 12> kZetaCoeffs{
      -7.1661652561756670113e18,
      1.8152105401943546773e17,
      -4.5979787224074726105e15,
      1.1646782814350067249e14,
      -2.950130727918164224e12,
      7.47242496e10,
      -1.8924375803183791606e9,
      47900160.0,
      -1209600.0,
      30240.0,
      -720.0,
      12.0,
  };

  // For speed we'll always use 9 iterations for the initial series estimate,
  // and a 12 term expansion for the Euler-Maclaurin formula.
  Value a = q;
  Value zero = chlo::getConstantLike(rewriter, loc, 0.0, a);
  Value neg_x = rewriter.create<mhlo::NegOp>(loc, x);
  // initial_sum = q^(-x) + (q + 1)^(-x) + ... + (q + 9)^(-x).
  Value initial_sum = rewriter.create<mhlo::PowOp>(loc, q, neg_x);
  Value one = chlo::getConstantLike(rewriter, loc, 1.0, a);
  for (int i = 0; i < 9; ++i) {
    a = rewriter.create<mhlo::AddOp>(loc, a, one);
    Value term = rewriter.create<mhlo::PowOp>(loc, a, neg_x);
    initial_sum = rewriter.create<mhlo::AddOp>(loc, initial_sum, term);
  }
  // From here on a = q + 10 and neg_power = a^(-x); the correction terms are
  // all expressed relative to these two values.
  a = rewriter.create<mhlo::AddOp>(loc, a, one);
  Value neg_power = rewriter.create<mhlo::PowOp>(loc, a, neg_x);

  // Integral term of Euler-Maclaurin: a^(1 - x) / (x - 1).
  Value one_like_x = chlo::getConstantLike(rewriter, loc, 1.0, x);
  Value x_minus_one = rewriter.create<mhlo::SubOp>(loc, x, one_like_x);
  Value neg_power_mul_a = rewriter.create<mhlo::MulOp>(loc, neg_power, a);
  Value neg_power_mul_a_div_x_minus_one =
      rewriter.create<mhlo::DivOp>(loc, neg_power_mul_a, x_minus_one);
  Value s = rewriter.create<mhlo::AddOp>(loc, initial_sum,
                                         neg_power_mul_a_div_x_minus_one);
  Value a_inverse_square = rewriter.create<mhlo::DivOp>(
      loc, one, rewriter.create<mhlo::MulOp>(loc, a, a));

  Value horner_sum = zero;
  // Use Horner's rule for this.
  // Note this differs from Cephes which does a 'naive' polynomial evaluation.
  // Using Horner's rule allows to avoid some NaN's and Infs from happening,
  // resulting in more numerically stable code.
  //
  // The k-th Euler-Maclaurin term carries the rising product
  // x * (x + 1) * ... * (x + 2k - 2), so every Horner step multiplies by
  // (x + (22 - 2i)) * (x + (21 - 2i)) / a^2. The constants are *added* to x;
  // subtracting them would, e.g., zero out the whole correction at x == 2.
  for (int i = 0; i < 11; ++i) {
    Value factor_lhs = rewriter.create<mhlo::AddOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 22 - 2 * i, x));
    Value factor_rhs = rewriter.create<mhlo::AddOp>(
        loc, x, chlo::getConstantLike(rewriter, loc, 21 - 2 * i, x));
    Value factor = rewriter.create<mhlo::MulOp>(loc, factor_lhs, factor_rhs);
    horner_sum = rewriter.create<mhlo::MulOp>(
        loc, factor,
        rewriter.create<mhlo::MulOp>(
            loc, a_inverse_square,
            rewriter.create<mhlo::AddOp>(
                loc, horner_sum,
                chlo::getConstantLike(rewriter, loc, 1. / kZetaCoeffs[i], a))));
  }
  // s += a^(-x) * (1/2 + (x / a) * (1 / kZetaCoeffs[11] + horner_sum)).
  Value zero_point_five_like_neg_power =
      chlo::getConstantLike(rewriter, loc, .5, neg_power);
  Value x_div_a = rewriter.create<mhlo::DivOp>(loc, x, a);
  s = rewriter.create<mhlo::AddOp>(
      loc, s,
      rewriter.create<mhlo::MulOp>(
          loc, neg_power,
          rewriter.create<mhlo::AddOp>(
              loc, zero_point_five_like_neg_power,
              rewriter.create<mhlo::MulOp>(
                  loc, x_div_a,
                  rewriter.create<mhlo::AddOp>(
                      loc,
                      chlo::getConstantLike(rewriter, loc,
                                            1. / kZetaCoeffs[11], a),
                      horner_sum)))));

  // Use the initial zeta sum without the correction term coming
  // from Euler-Maclaurin if it is accurate enough.
  // NOTE(review): the tolerance is whatever
  // getConstantLikeSmallestFiniteValue yields for the element type; confirm
  // that this (rather than a machine-epsilon style bound) is intended.
  const StringAttr kLT = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LT));
  Value abs_neg_power = rewriter.create<mhlo::AbsOp>(loc, neg_power);
  Value abs_initial_sum = rewriter.create<mhlo::AbsOp>(loc, initial_sum);
  Value output = rewriter.create<mhlo::SelectOp>(
      loc,
      rewriter.create<mhlo::CompareOp>(
          loc, abs_neg_power,
          rewriter.create<mhlo::MulOp>(
              loc, abs_initial_sum,
              chlo::getConstantLikeSmallestFiniteValue(rewriter, loc, a)),
          kLT),
      initial_sum, s);

  // The selects below are applied in sequence, so later ones take precedence
  // over earlier ones wherever their conditions overlap.

  // Function is not defined for x < 1.
  Value nan = chlo::getConstantLike(
      rewriter, loc, std::numeric_limits<double>::quiet_NaN(), x);
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one_like_x, kLT), nan,
      output);

  // For q <= 0, x must be an integer.
  const StringAttr kLE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::LE));
  const StringAttr kNE = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::NE));
  Value q_le_zero = rewriter.create<mhlo::CompareOp>(loc, q, zero, kLE);
  Value x_not_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kNE);
  Value x_domain_error =
      rewriter.create<mhlo::AndOp>(loc, q_le_zero, x_not_int);
  output = rewriter.create<mhlo::SelectOp>(loc, x_domain_error, nan, output);

  // For all integer q <= 0, zeta has a pole. The limit is only defined as
  // +inf if x is an even integer.
  const StringAttr kEQ = rewriter.getStringAttr(
      mhlo::stringifyComparisonDirection(mhlo::ComparisonDirection::EQ));
  Value inf = chlo::getConstantLike(rewriter, loc,
                                    std::numeric_limits<double>::infinity(), x);
  Value q_is_int = rewriter.create<mhlo::CompareOp>(
      loc, q, rewriter.create<mhlo::FloorOp>(loc, q), kEQ);
  Value at_pole = rewriter.create<mhlo::AndOp>(loc, q_le_zero, q_is_int);
  Value two = chlo::getConstantLike(rewriter, loc, 2.0, x);
  Value x_is_int = rewriter.create<mhlo::CompareOp>(
      loc, x, rewriter.create<mhlo::FloorOp>(loc, x), kEQ);
  Value x_is_even = rewriter.create<mhlo::CompareOp>(
      loc, rewriter.create<mhlo::RemOp>(loc, x, two), zero, kEQ);
  Value x_is_even_int = rewriter.create<mhlo::AndOp>(loc, x_is_int, x_is_even);
  output = rewriter.create<mhlo::SelectOp>(
      loc, at_pole,
      rewriter.create<mhlo::SelectOp>(loc, x_is_even_int, inf, nan), output);

  // For x = 1, this is the harmonic series and diverges.
  output = rewriter.create<mhlo::SelectOp>(
      loc, rewriter.create<mhlo::CompareOp>(loc, x, one, kEQ), inf, output);

  return output;
}