in torch_xla/csrc/random.cpp [95:146]
// Builds an XLA op that samples uniformly distributed values between
// `minval` and `maxval`, with the result having `shape`'s element type.
// Interval endpoint semantics are those of xla::UniformFloatingPointDistribution
// / xla::UniformIntDistribution.
//
// The seed, sampling shape and bounds are first normalized via MakeSeed,
// MakeRngShape and MakeUniformBoundaryValue. Sampling starts from a zero U64
// initial state and uses the generator returned by GetBitGenerator().
//
// Per element type:
//   F32 / F64          -> sampled directly in `rng_shape`.
//   F16 / BF16         -> sampled in `rng_shape`, then converted to the
//                         requested type. Presumably MakeRngShape widens the
//                         element type here; TODO confirm.
//   C64 / C128         -> real and imaginary parts sampled independently, the
//                         imaginary part with a seed derived as `seed * 17`.
//   S32/U32/S64/U64    -> sampled with the integer uniform distribution.
//   anything else      -> reported through XLA_ERROR().
xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape,
                      xla::XlaOp minval, xla::XlaOp maxval) {
  const xla::XlaOp seed_op = MakeSeed(seed);
  const xla::Shape sample_shape = MakeRngShape(shape);
  const xla::XlaOp lower = MakeUniformBoundaryValue(minval);
  const xla::XlaOp upper = MakeUniformBoundaryValue(maxval);
  const xla::XlaOp zero_state =
      xla::Zero(seed_op.builder(), xla::PrimitiveType::U64);

  // Emits one floating-point uniform sample of `sample_shape` keyed by `key`.
  auto sample_float = [&](xla::XlaOp key) -> xla::XlaOp {
    return xla::UniformFloatingPointDistribution(key, zero_state,
                                                 GetBitGenerator(), lower,
                                                 upper, sample_shape)
        .value;
  };

  const xla::PrimitiveType element_type = shape.element_type();
  switch (element_type) {
    case xla::PrimitiveType::F32:
    case xla::PrimitiveType::F64:
      return sample_float(seed_op);

    case xla::PrimitiveType::F16:
    case xla::PrimitiveType::BF16: {
      // Sample in the RNG shape, then narrow to the requested element type.
      const xla::XlaOp wide_sample = sample_float(seed_op);
      return xla::ConvertElementType(wide_sample, element_type);
    }

    case xla::PrimitiveType::C64:
    case xla::PrimitiveType::C128: {
      // Multiplier used to derive a second, different key so the imaginary
      // component is not identical to the real one.
      const xla::XlaOp imag_multiplier = XlaHelpers::ScalarValue<xla::uint64>(
          17, XlaHelpers::TypeOfXlaOp(seed_op), seed_op.builder());
      const xla::XlaOp real_part = sample_float(seed_op);
      const xla::XlaOp imag_part = sample_float(seed_op * imag_multiplier);
      return xla::Complex(real_part, imag_part);
    }

    case xla::PrimitiveType::S32:
    case xla::PrimitiveType::U32:
    case xla::PrimitiveType::S64:
    case xla::PrimitiveType::U64: {
      const xla::XlaOp int_sample =
          xla::UniformIntDistribution(seed_op, zero_state, GetBitGenerator(),
                                      lower, upper, sample_shape)
              .value;
      return int_sample;
    }

    default:
      XLA_ERROR() << "RngUniform not implemented for type "
                  << xla::primitive_util::LowercasePrimitiveTypeName(
                         shape.element_type());
  }
}