xla::XlaOp RngUniform()

in torch_xla/csrc/random.cpp [95:146]


// Emits XLA ops that draw random values uniformly between `minval` and
// `maxval`, with the dimensions and element type described by `shape`.
// (Exact range semantics — e.g. whether `maxval` is inclusive — are those of
// xla::UniformFloatingPointDistribution / xla::UniformIntDistribution.)
//
// `seed` is normalized through MakeSeed() and used as the key for the bit
// generator returned by GetBitGenerator(); the generator state always starts
// from a U64 zero. Samples are produced in the shape returned by
// MakeRngShape(), with both bounds passed through MakeUniformBoundaryValue().
// Per requested element type:
//   * F32 / F64         : sampled directly.
//   * F16 / BF16        : sampled in the RNG shape, then converted to the
//                         requested element type.
//   * C64 / C128        : real and imaginary parts are sampled separately and
//                         combined with xla::Complex().
//   * S32/U32/S64/U64   : sampled with the integer distribution.
// Any other element type is reported through XLA_ERROR().
xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape,
                      xla::XlaOp minval, xla::XlaOp maxval) {
  const xla::PrimitiveType target_type = shape.element_type();
  const xla::XlaOp key = MakeSeed(seed);
  const xla::Shape sample_shape = MakeRngShape(shape);
  const xla::XlaOp low = MakeUniformBoundaryValue(minval);
  const xla::XlaOp high = MakeUniformBoundaryValue(maxval);
  const xla::XlaOp zero_state =
      xla::Zero(key.builder(), xla::PrimitiveType::U64);

  // Floating point draw for a given key; every floating point path shares the
  // bounds, the zero initial state and the sample shape.
  auto sample_float = [&](xla::XlaOp rng_key) {
    return xla::UniformFloatingPointDistribution(rng_key, zero_state,
                                                 GetBitGenerator(), low, high,
                                                 sample_shape)
        .value;
  };

  switch (target_type) {
    case xla::PrimitiveType::F32:
    case xla::PrimitiveType::F64:
      return sample_float(key);

    case xla::PrimitiveType::F16:
    case xla::PrimitiveType::BF16:
      // The draw is done in `sample_shape`; narrow it back to the requested
      // element type.
      return xla::ConvertElementType(sample_float(key), target_type);

    case xla::PrimitiveType::C64:
    case xla::PrimitiveType::C128: {
      // The imaginary part is drawn with the key multiplied by 17, the real
      // part with the key itself.
      // NOTE(review): a key of 0 maps to 0 here, so both parts would then use
      // the same key/state/shape/bounds and come out identical — confirm that
      // MakeSeed() can never yield 0, or chain the RngOutput state instead.
      const xla::XlaOp multiplier = XlaHelpers::ScalarValue<xla::uint64>(
          17, XlaHelpers::TypeOfXlaOp(key), key.builder());
      const xla::XlaOp real_part = sample_float(key);
      const xla::XlaOp imag_part = sample_float(key * multiplier);
      return xla::Complex(real_part, imag_part);
    }

    case xla::PrimitiveType::S32:
    case xla::PrimitiveType::U32:
    case xla::PrimitiveType::S64:
    case xla::PrimitiveType::U64:
      return xla::UniformIntDistribution(key, zero_state, GetBitGenerator(),
                                         low, high, sample_shape)
          .value;

    default:
      XLA_ERROR() << "RngUniform not implemented for type "
                  << xla::primitive_util::LowercasePrimitiveTypeName(
                         target_type);
  }
}