xla::PrimitiveType GetDevicePrimitiveType()

in Sources/x10/xla_tensor/tensor_util.cpp [784:820]


// Returns the element type used to store tensors of logical type `type` on
// `device`. The device is resolved through GetDeviceOrCurrent(), which
// presumably falls back to the current device when `device` is null
// (NOTE(review): confirm against its definition).
//
// Mapping:
//  - F64:  F16 if UseF16(), else BF16 if UseBF16(), else F32 on TPU and F64
//          on every other device type.
//  - F32:  F16 if UseF16(), else BF16 if UseBF16(), else F32.
//  - U16 / S16: become U32 / S32 on TPU; unchanged elsewhere.
//  - C128: becomes C64 on TPU; unchanged elsewhere.
//  - S64 / U64: narrowed to S32 / U32 when Use32BitLong() is set.
//  - Every other type is returned unchanged.
xla::PrimitiveType GetDevicePrimitiveType(xla::PrimitiveType type,
                                          const Device* device) {
  // Resolved up front for every type, exactly as before, even when the
  // mapping below ends up not depending on the device.
  const bool on_tpu = GetDeviceOrCurrent(device).hw_type == DeviceType::TPU;

  switch (type) {
    case xla::PrimitiveType::F64:
      // The global half-precision switches win over the device-based choice.
      if (UseF16()) return xla::PrimitiveType::F16;
      if (UseBF16()) return xla::PrimitiveType::BF16;
      return on_tpu ? xla::PrimitiveType::F32 : xla::PrimitiveType::F64;

    case xla::PrimitiveType::F32:
      // Once S4TF supports a native BF16 type, this global configuration can
      // be replaced (or augmented) by a proper per-type mapping.
      if (UseF16()) return xla::PrimitiveType::F16;
      if (UseBF16()) return xla::PrimitiveType::BF16;
      return xla::PrimitiveType::F32;

    // Types whose storage type depends only on whether the device is a TPU.
    case xla::PrimitiveType::U16:
      return on_tpu ? xla::PrimitiveType::U32 : xla::PrimitiveType::U16;
    case xla::PrimitiveType::S16:
      return on_tpu ? xla::PrimitiveType::S32 : xla::PrimitiveType::S16;
    case xla::PrimitiveType::C128:
      return on_tpu ? xla::PrimitiveType::C64 : xla::PrimitiveType::C128;

    // 64-bit integers, governed by the global 32-bit-long switch.
    case xla::PrimitiveType::S64:
      return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
    case xla::PrimitiveType::U64:
      return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;

    default:
      break;
  }
  // No device-specific remapping applies.
  return type;
}