in torch_xla/csrc/tensor_util.cpp [647:680]
at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
                                    at::ScalarType dest_element_type) {
  // Picks the XlaLiteralToTensor<SType, D> instantiation whose destination
  // element type D matches the requested at::ScalarType, and forwards both
  // arguments to it unchanged. Scalar types without a case below are reported
  // through XLA_ERROR().
  //
  // NOTE(review): SType is not declared anywhere in this span; presumably it
  // is a template parameter introduced immediately above this definition --
  // confirm against the surrounding file.
  using ScalarType = at::ScalarType;

  switch (dest_element_type) {
    // Boolean and integral destinations.
    case ScalarType::Bool:
      return XlaLiteralToTensor<SType, bool>(literal, dest_element_type);
    case ScalarType::Byte:
      return XlaLiteralToTensor<SType, uint8_t>(literal, dest_element_type);
    case ScalarType::Char:
      return XlaLiteralToTensor<SType, int8_t>(literal, dest_element_type);
    case ScalarType::Short:
      return XlaLiteralToTensor<SType, int16_t>(literal, dest_element_type);
    case ScalarType::Int:
      return XlaLiteralToTensor<SType, int32_t>(literal, dest_element_type);
    case ScalarType::Long:
      return XlaLiteralToTensor<SType, int64_t>(literal, dest_element_type);

    // Floating-point destinations.
    case ScalarType::Float:
      return XlaLiteralToTensor<SType, float>(literal, dest_element_type);
    case ScalarType::Double:
      return XlaLiteralToTensor<SType, double>(literal, dest_element_type);
    case ScalarType::BFloat16:
      return XlaLiteralToTensor<SType, at::BFloat16>(literal,
                                                     dest_element_type);
    case ScalarType::Half:
      return XlaLiteralToTensor<SType, at::Half>(literal, dest_element_type);

    // Complex destinations.
    case ScalarType::ComplexFloat:
      return XlaLiteralToTensor<SType, c10::complex<float>>(literal,
                                                            dest_element_type);
    case ScalarType::ComplexDouble:
      return XlaLiteralToTensor<SType, c10::complex<double>>(
          literal, dest_element_type);

    default:
      // No conversion is defined for this type; fall out of the switch to the
      // error report below.
      break;
  }

  XLA_ERROR() << "Unsupported scalar type: " << dest_element_type;
}