Tensor Tensor::getCopyConvertedToType()

in lib/Base/Tensor.cpp [830:945]


/// \returns a copy of this tensor whose elements are converted to \p newKind.
///
/// The tensor must not be device resident. Only the (source kind, \p newKind)
/// pairs enumerated in the DCHECK at the top of the function are supported:
///   - when \p newKind is not a quantized kind, the result has the same dims()
///     and is filled element by element through copyWithCast();
///   - fused conversions are delegated to convertToUInt8FusedQTy() /
///     convertToUInt4FusedQTy(), except UInt8FusedQTy -> UInt8FusedFP16QTy,
///     which is performed inline, row by row.
Tensor Tensor::getCopyConvertedToType(ElemKind newKind) const {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  const ElemKind origKind = getElementType();
  DCHECK((origKind == ElemKind::FloatTy && newKind == ElemKind::Float16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::BFloat16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Float16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::BFloat16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::UInt8FusedQTy &&
          newKind == ElemKind::UInt8FusedFP16QTy) ||
         (origKind == ElemKind::UInt8FusedFP16QTy &&
          newKind == ElemKind::UInt8FusedQTy) ||
         (origKind == ElemKind::UInt4FusedFP16QTy &&
          newKind == ElemKind::UInt8FusedQTy) ||
         (origKind == ElemKind::UInt4FusedFP16QTy &&
          newKind == ElemKind::UInt4FusedQTy) ||
         (origKind == ElemKind::UInt4FusedQTy &&
          newKind == ElemKind::UInt8FusedQTy))
      << "Conversion from " << Type::getElementName(origKind).str() << " to "
      << Type::getElementName(newKind).str() << " is not yet implemented";

  // Non-quantized destination: same shape, per-element cast.
  if (!isQuantizedElemKind(newKind)) {
    Tensor tmp(newKind, dims());
    switch (newKind) {
    case ElemKind::Float16Ty:
      tmp.copyWithCast<float16_t, float>(this);
      break;
    case ElemKind::BFloat16Ty:
      tmp.copyWithCast<bfloat16_t, float>(this);
      break;

    case ElemKind::FloatTy:
      if (origKind == ElemKind::Int32ITy) {
        tmp.copyWithCast<float, int32_t>(this);
      } else if (origKind == ElemKind::Int64ITy) {
        tmp.copyWithCast<float, int64_t>(this);
      } else if (origKind == ElemKind::Float16Ty) {
        tmp.copyWithCast<float, float16_t>(this);
      } else if (origKind == ElemKind::BFloat16Ty) {
        tmp.copyWithCast<float, bfloat16_t>(this);
      } else if (origKind == ElemKind::FloatTy) {
        // NOTE(review): FloatTy -> FloatTy is handled here but is not in the
        // DCHECK list above, so this branch is only reachable when that check
        // is compiled out -- confirm whether the pair should be listed.
        tmp.copyRawFrom(this);
      } else {
        llvm_unreachable("Invalid conversion to FLOAT.");
      }
      break;

    case ElemKind::Int32ITy:
      if (origKind == ElemKind::Int64ITy) {
        tmp.copyWithCast<int32_t, int64_t>(this);
      } else if (origKind == ElemKind::FloatTy) {
        tmp.copyWithCast<int32_t, float>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT32.");
      }
      break;

    case ElemKind::Int64ITy:
      if (origKind == ElemKind::Int32ITy) {
        tmp.copyWithCast<int64_t, int32_t>(this);
      } else if (origKind == ElemKind::FloatTy) {
        // FloatTy -> Int64ITy is advertised as supported by the DCHECK above,
        // so it must be implemented here rather than hit llvm_unreachable.
        tmp.copyWithCast<int64_t, float>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT64.");
      }
      break;

    default:
      llvm_unreachable("Type not supported");
    }
    return tmp;
  }

  // Handle Fused conversion.
  if ((origKind == ElemKind::UInt8FusedFP16QTy ||
       origKind == ElemKind::UInt4FusedFP16QTy) &&
      newKind == ElemKind::UInt8FusedQTy) {
    return convertToUInt8FusedQTy<float16_t>(this);
  }
  if (origKind == ElemKind::UInt4FusedQTy &&
      newKind == ElemKind::UInt8FusedQTy) {
    return convertToUInt8FusedQTy<float>(this);
  }
  if (origKind == ElemKind::UInt4FusedFP16QTy &&
      newKind == ElemKind::UInt4FusedQTy) {
    return convertToUInt4FusedQTy(this);
  }

  // Supports UInt8FusedQTy -> UInt8FusedFP16QTy.
  DCHECK(origKind == ElemKind::UInt8FusedQTy && dims().size() == 2)
      << "UInt8FusedQTy must be 2 dimensional.";

  // Each destination row is narrower than the source row by the size
  // difference between a (float, float) and a (float16_t, float16_t) pair.
  const dim_t rowShrink = 2 * (static_cast<dim_t>(sizeof(float)) -
                               static_cast<dim_t>(sizeof(float16_t)));
  Tensor tmp(newKind, {dims()[0], dims()[1] - rowShrink}, 1.0, 0);

  const size_t dstWidth = tmp.dims()[1];
  auto srcH = getHandle<uint8_t>();
  auto dstH = tmp.getHandle<uint8_t>();
  for (dim_t i = 0, e = dims()[0]; i < e; i++) {
    // Copy the scale/offset from src to dst, narrowing them to float16_t.
    float scale, offset;
    std::tie(scale, offset) = srcH.getFusedScaleOffsetFromRow<float>(i);
    dstH.setFusedScaleOffsetInRow<float16_t>(i, static_cast<float16_t>(scale),
                                             static_cast<float16_t>(offset));

    // Copy over the row's uint8 data from src to dst; scales and offsets were
    // already copied over above, so stop before the last
    // 2 * sizeof(float16_t) bytes of the destination row.
    for (dim_t j = 0, f = dstWidth - 2 * sizeof(float16_t); j < f; j++) {
      dstH.at({i, j}) = srcH.at({i, j});
    }
  }
  return tmp;
}