tensor_idx_t Weaver::MakeConstantSerialized()

in tensorflow_fold/loom/weaver.cc [381:441]


// Creates a constant tensor of TypeShape `ts_idx` from raw serialized bytes
// and hands it to MakeConstant().
//
// `tensor_bytes` is copied verbatim (memcpy) into the new tensor's buffer, so
// it must contain exactly num_elements * sizeof(element type) bytes in the
// element type's in-memory representation. Only the fixed-size dtypes listed
// in the switch below are supported (DT_STRING is not).
//
// Returns the result of MakeConstant() on success. On failure (invalid
// TypeShape ID, wrong byte count, TypeShape in batch mode, or unsupported
// dtype) sets error_string_ and returns -1.
tensor_idx_t Weaver::MakeConstantSerialized(
    tensor_idx_t ts_idx, const string &tensor_bytes) {
  if (!FastBoundsCheck(ts_idx, num_type_shapes_)) {
    error_string_ = StrCat("Invalid TypeShape ID: ", ts_idx);
    return -1;
  }
  const auto &type_shape = type_shapes_[ts_idx];
  const tensor_idx_t num_elements = type_shape.shape.num_elements();
  // Compare byte counts as signed values to avoid a signed/unsigned mismatch
  // against string::size().
  const tensor_idx_t num_bytes = static_cast<tensor_idx_t>(tensor_bytes.size());

  tensor_idx_t dt_size = tensorflow::DataTypeSize(type_shape.dtype);
  if (dt_size != 0 && dt_size * num_elements != num_bytes) {
    // Note: we only bother with the size check if dt_size != 0, because when
    // dt_size is zero, that means we're dealing with DataType without a fixed
    // size (like string.) The per-type check in the switch below still guards
    // the memcpy in case DataTypeSize() reports 0 for a type handled there.
    error_string_ = StrCat("Invalid serialized tensor passed in; has ",
                           tensor_bytes.size(), " bytes, expected: ",
                           dt_size * num_elements);
    return -1;
  }

  const auto &ts_metadata = metadata_.type_shape_metadata(ts_idx);
  if (ts_metadata.is_batch_input()) {
    error_string_ = StrCat(
        "Cannot create a constant for a TypeShape ", ts_idx,
        " which is in batch mode.");
    return -1;
  }

  Tensor tensor(type_shape.dtype, type_shape.shape);
  // For each supported dtype: verify the byte count against the element size
  // the tensor buffer was actually allocated with (so memcpy can never
  // overrun it), then copy the raw bytes in. The copy is skipped for empty
  // tensors, whose data pointer may be null.
  switch (type_shape.dtype) {
#define HANDLE_CASE(_tensor_type_)                                          \
    case _tensor_type_: {                                                   \
      using ElemType = EnumToDataType<_tensor_type_>::Type;                 \
      const tensor_idx_t expected_bytes =                                   \
          static_cast<tensor_idx_t>(sizeof(ElemType)) * num_elements;       \
      if (expected_bytes != num_bytes) {                                    \
        error_string_ = StrCat("Invalid serialized tensor passed in; has ", \
                               tensor_bytes.size(), " bytes, expected: ",   \
                               expected_bytes);                             \
        return -1;                                                          \
      }                                                                     \
      if (num_bytes > 0) {                                                  \
        memcpy(tensor.flat<ElemType>().data(), tensor_bytes.data(),         \
               tensor_bytes.size());                                        \
      }                                                                     \
      break;                                                                \
    }
    HANDLE_CASE(DT_FLOAT);
    HANDLE_CASE(DT_DOUBLE);
    HANDLE_CASE(DT_INT32);
    HANDLE_CASE(DT_UINT16);
    HANDLE_CASE(DT_UINT8);
    HANDLE_CASE(DT_INT16);
    HANDLE_CASE(DT_INT8);
    // HANDLE_CASE(DT_STRING);  // String isn't supported.
    HANDLE_CASE(DT_COMPLEX64);
    HANDLE_CASE(DT_COMPLEX128);
    HANDLE_CASE(DT_INT64);
    HANDLE_CASE(DT_BOOL);
    HANDLE_CASE(DT_QINT8);
    HANDLE_CASE(DT_QUINT8);
    HANDLE_CASE(DT_QINT16);
    HANDLE_CASE(DT_QUINT16);
    HANDLE_CASE(DT_QINT32);
    HANDLE_CASE(DT_BFLOAT16);
    HANDLE_CASE(DT_HALF);
#undef HANDLE_CASE
    default:
      // Report through error_string_ like every other failure in this
      // function instead of aborting the whole process.
      error_string_ = StrCat(
          "Weaver.MakeConstantSerialized does not support tensors of type ",
          DataType_Name(type_shape.dtype));
      return -1;
  }

  return MakeConstant(ts_idx, tensor);
}