DLDataType RelayVMModel::GetInputDLDataType(int index)

in src/dlr_relayvm.cc [232:279]


/*!
 * \brief Translate the dtype string recorded for an input into a DLPack DLDataType.
 * \param index Position of the input within input_types_.
 * \return A scalar DLDataType (lanes == 1) whose code/bits match the input's dtype string.
 * \throws dmlc::Error if index is out of range or the dtype string is not recognized.
 */
DLDataType RelayVMModel::GetInputDLDataType(int index) {
  // Reject bad indices explicitly instead of indexing out of bounds (undefined behavior).
  if (index < 0 || static_cast<size_t>(index) >= input_types_.size()) {
    throw dmlc::Error(std::string("Invalid input index: ") + std::to_string(index));
  }
  // Bind by reference; no need to copy the dtype string.
  const auto& input_type = input_types_[index];
  DLDataType dtype;
  dtype.lanes = 1;
  if (input_type == "bool") {
    // bool is represented as a 1-bit unsigned integer.
    dtype.code = kDLUInt;
    dtype.bits = 1;
  } else if (input_type == "uint8") {
    dtype.code = kDLUInt;
    dtype.bits = 8;
  } else if (input_type == "int8") {
    dtype.code = kDLInt;
    dtype.bits = 8;
  } else if (input_type == "uint16") {
    dtype.code = kDLUInt;
    dtype.bits = 16;
  } else if (input_type == "int16") {
    dtype.code = kDLInt;
    dtype.bits = 16;
  } else if (input_type == "uint32") {
    dtype.code = kDLUInt;
    dtype.bits = 32;
  } else if (input_type == "int32") {
    dtype.code = kDLInt;
    dtype.bits = 32;
  } else if (input_type == "uint64") {
    dtype.code = kDLUInt;
    dtype.bits = 64;
  } else if (input_type == "int64") {
    dtype.code = kDLInt;
    dtype.bits = 64;
  } else if (input_type == "float16") {
    dtype.code = kDLFloat;
    dtype.bits = 16;
  } else if (input_type == "bfloat16") {
    dtype.code = kDLBfloat;
    dtype.bits = 16;
  } else if (input_type == "float32") {
    dtype.code = kDLFloat;
    dtype.bits = 32;
  } else if (input_type == "float64") {
    // IEEE double precision is kDLFloat; kDLBfloat is only the 16-bit bfloat format.
    dtype.code = kDLFloat;
    dtype.bits = 64;
  } else {
    throw dmlc::Error(std::string("Unknown input dtype: ") + input_type);
  }
  return dtype;
}