in torch_xla/csrc/random.cpp [148:199]
xla::XlaOp RngNormal(xla::XlaOp seed, const xla::Shape& shape, xla::XlaOp mean,
                     xla::XlaOp std) {
  xla::XlaOp rng_seed = MakeSeed(seed);
  xla::Shape rng_shape = MakeRngShape(shape);
  xla::XlaOp initial_state =
      xla::Zero(rng_seed.builder(), xla::PrimitiveType::U64);
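  // XLA's stateless (counter-based) RNG API is used throughout: `rng_seed` is
  // passed as the key and the counter starts at zero.  GetBitGenerator()
  // supplies the concrete bit generator; which one is used is decided
  // elsewhere in this file.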
  switch (shape.element_type()) {
    case xla::PrimitiveType::F16:
    case xla::PrimitiveType::BF16: {
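      // Reduced-precision floats: the raw sample comes out of the generator at
      // the element type of `rng_shape`, the affine transform mean + rng * std
      // is applied in F32, and the result is cast back to the requested
      // F16/BF16 type.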
      xla::XlaOp f32_mean = MaybeConvertTo(mean, xla::PrimitiveType::F32);
      xla::XlaOp f32_std = MaybeConvertTo(std, xla::PrimitiveType::F32);
      xla::XlaOp rng =
          xla::NormalFloatingPointDistribution(rng_seed, initial_state,
                                               GetBitGenerator(), rng_shape)
              .value;
      return xla::ConvertElementType(f32_mean + rng * f32_std,
                                     shape.element_type());
    }
    case xla::PrimitiveType::F32:
    case xla::PrimitiveType::F64: {
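      // Full-precision floats: use the sample directly; the Promoted* helpers
      // handle any element-type promotion between mean, std and the sample.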
      xla::XlaOp rng =
          xla::NormalFloatingPointDistribution(rng_seed, initial_state,
                                               GetBitGenerator(), rng_shape)
              .value;
      return XlaHelpers::PromotedAdd(mean, XlaHelpers::PromotedMul(rng, std));
    }
    case xla::PrimitiveType::C64:
    case xla::PrimitiveType::C128: {
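      // Complex types: draw two independent real-valued samples, one for the
      // real and one for the imaginary part.  The imaginary draw reuses the
      // same generator with a key derived from the original one (scaled by the
      // constant 17 below) so the two streams are decorrelated.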
      xla::XlaOp k_seed = XlaHelpers::ScalarValue<xla::uint64>(
          17, XlaHelpers::TypeOfXlaOp(rng_seed), rng_seed.builder());
      xla::XlaOp rng_real =
          xla::NormalFloatingPointDistribution(rng_seed, initial_state,
                                               GetBitGenerator(), rng_shape)
              .value;
      xla::XlaOp rng_imag =
          xla::NormalFloatingPointDistribution(rng_seed * k_seed, initial_state,
                                               GetBitGenerator(), rng_shape)
              .value;
      xla::XlaOp rng = xla::Complex(rng_real, rng_imag);
      // The real and imaginary parts each carry half of the requested
      // variance, so the standard deviation of each part is scaled by
      // 1/sqrt(2) to keep the total variance at std^2.
      xla::XlaOp sqrtTwo = XlaHelpers::ScalarValue(
          std::sqrt(2), XlaHelpers::TypeOfXlaOp(std), rng_seed.builder());
      return XlaHelpers::PromotedAdd(
          mean, XlaHelpers::PromotedMul(rng, std / sqrtTwo));
    }
    default:
      XLA_ERROR() << "RngNormal not implemented for type "
                  << xla::primitive_util::LowercasePrimitiveTypeName(
                         shape.element_type());
  }
}
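
// The stateless-RNG pattern used above can be exercised on its own.  The
// commented-out sketch below is an illustration only, not part of this file:
// it assumes the XLA client library headers xla/client/lib/prng.h,
// xla/client/lib/constants.h and xla/shape_util.h, hard-codes ThreeFry as the
// bit generator instead of going through GetBitGenerator(), and assumes `seed`
// is already a U64 scalar as produced by MakeSeed() above.
//
//   // Draw a standard-normal F32 tensor with the given dimensions.
//   xla::XlaOp StandardNormalF32(xla::XlaBuilder* builder, xla::XlaOp seed,
//                                absl::Span<const int64_t> dims) {
//     xla::Shape shape =
//         xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, dims);
//     // Counter-based RNG: the counter (state) starts at zero, `seed` is the
//     // key.
//     xla::XlaOp state = xla::Zero(builder, xla::PrimitiveType::U64);
//     return xla::NormalFloatingPointDistribution(
//                seed, state,
//                [](xla::XlaOp key, xla::XlaOp st, const xla::Shape& s) {
//                  return xla::ThreeFryBitGenerator(key, st, s);
//                },
//                shape)
//         .value;  // RngOutput::value holds the sampled tensor.
//   }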