in torch_xla/csrc/ops/ops.cpp [484:551]
// Builds a Constant node holding the values produced by
// XlaHelpers::Range(start, end, step). The values are computed in the XLA
// element type that MakeXlaPrimitiveType derives from `scalar_type`, with each
// Scalar converted through the accessor matching that type.
// (Presumably the half-open range [start, end) like torch.arange — exact
// endpoint semantics are those of XlaHelpers::Range; confirm there.)
//
// Args:
//   start: first value of the range.
//   end: bound of the range.
//   step: increment between consecutive values; must be non-zero and must
//     point from `start` towards `end`.
//   scalar_type: ATen scalar type of the produced values.
//
// Fails via XLA_CHECK when step is zero, when start or end is NaN or infinite,
// or when the sign of step is inconsistent with the bounds; fails via
// XLA_ERROR when the mapped XLA type is not handled below.
NodePtr ARange(const at::Scalar& start, const at::Scalar& end,
               const at::Scalar& step, at::ScalarType scalar_type) {
  xla::PrimitiveType type = MakeXlaPrimitiveType(scalar_type,
                                                 /*device=*/nullptr);
  // Validate once in double precision, independent of the target type.
  const double start_value = start.toDouble();
  const double end_value = end.toDouble();
  const double step_value = step.toDouble();
  XLA_CHECK_NE(step_value, 0.0) << "step must be nonzero";
  // Reject infinite bounds as well as NaN: an infinite bound would describe an
  // unbounded number of elements. This mirrors the isfinite check in ATen's
  // arange.
  XLA_CHECK(std::isfinite(start_value) && std::isfinite(end_value))
      << "unsupported range: " << start_value << " -> " << end_value;
  // A NaN step also fails here, since every comparison with NaN is false.
  XLA_CHECK((start_value <= end_value && step_value > 0.0) ||
            (start_value >= end_value && step_value < 0.0))
      << "upper bound and lower bound inconsistent with step sign";
  xla::Literal values;
  // NOTE(review): for the unsigned cases a negative step is converted to an
  // unsigned type before reaching XlaHelpers::Range; confirm Range handles
  // descending unsigned ranges as intended.
  switch (type) {
    case xla::PrimitiveType::BF16:
      values = XlaHelpers::Range<tensorflow::bfloat16>(
          static_cast<tensorflow::bfloat16>(start.toFloat()),
          static_cast<tensorflow::bfloat16>(end.toFloat()),
          static_cast<tensorflow::bfloat16>(step.toFloat()));
      break;
    case xla::PrimitiveType::F16:
      values =
          XlaHelpers::Range<xla::half>(static_cast<xla::half>(start.toHalf()),
                                       static_cast<xla::half>(end.toHalf()),
                                       static_cast<xla::half>(step.toHalf()));
      break;
    case xla::PrimitiveType::F32:
      values = XlaHelpers::Range<float>(start.toFloat(), end.toFloat(),
                                        step.toFloat());
      break;
    case xla::PrimitiveType::F64:
      values = XlaHelpers::Range<double>(start_value, end_value, step_value);
      break;
    case xla::PrimitiveType::U8:
      values = XlaHelpers::Range<xla::uint8>(start.toByte(), end.toByte(),
                                             step.toByte());
      break;
    case xla::PrimitiveType::S8:
      values = XlaHelpers::Range<xla::int8>(start.toChar(), end.toChar(),
                                            step.toChar());
      break;
    case xla::PrimitiveType::S16:
      values = XlaHelpers::Range<xla::int16>(start.toShort(), end.toShort(),
                                             step.toShort());
      break;
    case xla::PrimitiveType::U16:
      values = XlaHelpers::Range<xla::uint16>(start.toInt(), end.toInt(),
                                              step.toInt());
      break;
    case xla::PrimitiveType::S32:
      values = XlaHelpers::Range<xla::int32>(start.toInt(), end.toInt(),
                                             step.toInt());
      break;
    case xla::PrimitiveType::U32:
      values = XlaHelpers::Range<xla::uint32>(start.toLong(), end.toLong(),
                                              step.toLong());
      break;
    case xla::PrimitiveType::S64:
      values = XlaHelpers::Range<xla::int64_t>(start.toLong(), end.toLong(),
                                               step.toLong());
      break;
    case xla::PrimitiveType::U64:
      values = XlaHelpers::Range<xla::uint64>(start.toLong(), end.toLong(),
                                              step.toLong());
      break;
    default:
      XLA_ERROR() << "XLA type not supported: " << type;
  }
  return MakeNode<Constant>(std::move(values));
}