in lib/Base/Tensor.cpp [830:945]
/// \returns a new Tensor holding a copy of this tensor's data converted to
/// element kind \p newKind. This tensor is not modified.
///
/// The tensor must reside on the host. The supported (source -> destination)
/// pairs are exactly the ones enumerated in the DCHECK below; any other pair is
/// reported as not yet implemented (in builds where DCHECK is active).
///
/// Non-quantized destinations are produced by an element-wise cast into a
/// tensor with the same dims. Fused-quantized destinations are delegated to
/// convertToUInt8FusedQTy / convertToUInt4FusedQTy, except for
/// UInt8FusedQTy -> UInt8FusedFP16QTy which is handled inline at the end.
Tensor Tensor::getCopyConvertedToType(ElemKind newKind) const {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  const ElemKind origKind = getElementType();
  DCHECK((origKind == ElemKind::FloatTy && newKind == ElemKind::Float16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::BFloat16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Float16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::BFloat16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::UInt8FusedQTy &&
          newKind == ElemKind::UInt8FusedFP16QTy) ||
         (origKind == ElemKind::UInt8FusedFP16QTy &&
          newKind == ElemKind::UInt8FusedQTy) ||
         (origKind == ElemKind::UInt4FusedFP16QTy &&
          newKind == ElemKind::UInt8FusedQTy) ||
         (origKind == ElemKind::UInt4FusedFP16QTy &&
          newKind == ElemKind::UInt4FusedQTy) ||
         (origKind == ElemKind::UInt4FusedQTy &&
          newKind == ElemKind::UInt8FusedQTy))
      << "Conversion from " << Type::getElementName(origKind).str() << " to "
      << Type::getElementName(newKind).str() << " is not yet implemented";

  // Non-quantized destination: same dims, element-wise cast.
  if (!isQuantizedElemKind(newKind)) {
    Tensor tmp(newKind, dims());
    switch (newKind) {
    case ElemKind::Float16Ty:
      // Only FloatTy is an accepted source for this destination.
      tmp.copyWithCast<float16_t, float>(this);
      break;
    case ElemKind::BFloat16Ty:
      // Only FloatTy is an accepted source for this destination.
      tmp.copyWithCast<bfloat16_t, float>(this);
      break;
    case ElemKind::FloatTy:
      if (origKind == ElemKind::Int32ITy) {
        tmp.copyWithCast<float, int32_t>(this);
      } else if (origKind == ElemKind::Int64ITy) {
        tmp.copyWithCast<float, int64_t>(this);
      } else if (origKind == ElemKind::Float16Ty) {
        tmp.copyWithCast<float, float16_t>(this);
      } else if (origKind == ElemKind::BFloat16Ty) {
        tmp.copyWithCast<float, bfloat16_t>(this);
      } else if (origKind == ElemKind::FloatTy) {
        // Same element type on both sides: a raw copy is sufficient.
        tmp.copyRawFrom(this);
      } else {
        llvm_unreachable("Invalid conversion to FLOAT.");
      }
      break;
    case ElemKind::Int32ITy:
      if (origKind == ElemKind::Int64ITy) {
        tmp.copyWithCast<int32_t, int64_t>(this);
      } else if (origKind == ElemKind::FloatTy) {
        tmp.copyWithCast<int32_t, float>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT32.");
      }
      break;
    case ElemKind::Int64ITy:
      if (origKind == ElemKind::Int32ITy) {
        tmp.copyWithCast<int64_t, int32_t>(this);
      } else if (origKind == ElemKind::FloatTy) {
        // FloatTy -> Int64ITy is accepted by the check above, so it must be
        // handled here rather than falling into the unreachable branch.
        tmp.copyWithCast<int64_t, float>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT64.");
      }
      break;
    default:
      llvm_unreachable("Type not supported");
    }
    return tmp;
  }

  // Handle Fused conversion.
  if ((origKind == ElemKind::UInt8FusedFP16QTy ||
       origKind == ElemKind::UInt4FusedFP16QTy) &&
      newKind == ElemKind::UInt8FusedQTy) {
    return convertToUInt8FusedQTy<float16_t>(this);
  }
  if (origKind == ElemKind::UInt4FusedQTy &&
      newKind == ElemKind::UInt8FusedQTy) {
    return convertToUInt8FusedQTy<float>(this);
  }
  if (origKind == ElemKind::UInt4FusedFP16QTy &&
      newKind == ElemKind::UInt4FusedQTy) {
    return convertToUInt4FusedQTy(this);
  }

  // Supports UInt8FusedQTy -> UInt8FusedFP16QTy.
  DCHECK(origKind == ElemKind::UInt8FusedQTy && dims().size() == 2)
      << "UInt8FusedQTy must be 2 dimensional.";

  // Each row carries two per-row values (scale and offset). Going from float
  // to float16_t shrinks each of them, so the destination row is narrower by
  // twice the size difference.
  const dim_t rowShrink =
      2 * ((dim_t)sizeof(float) - (dim_t)sizeof(float16_t));
  Tensor tmp(newKind, {dims()[0], dims()[1] - rowShrink}, 1.0, 0);

  // Number of leading columns in a destination row that are plain uint8 data,
  // i.e. the row width minus the two float16_t scale/offset values.
  const dim_t dataWidth = tmp.dims()[1] - 2 * sizeof(float16_t);

  auto srcH = getHandle<uint8_t>();
  auto dstH = tmp.getHandle<uint8_t>();
  for (dim_t i = 0, e = dims()[0]; i < e; i++) {
    // Copy the scale/offset from src to dst, narrowing float -> float16_t.
    float scale, offset;
    std::tie(scale, offset) = srcH.getFusedScaleOffsetFromRow<float>(i);
    dstH.setFusedScaleOffsetInRow<float16_t>(i, static_cast<float16_t>(scale),
                                             static_cast<float16_t>(offset));
    // Copy over the row's uint8 data from src to dst; scales and offsets were
    // already copied over above.
    for (dim_t j = 0; j < dataWidth; j++) {
      dstH.at({i, j}) = srcH.at({i, j});
    }
  }
  return tmp;
}