in tensorflow_fold/loom/weaver.cc [381:441]
// Creates a constant of TypeShape `ts_idx` from the raw bytes in
// `tensor_bytes`: the bytes are copied verbatim into a newly constructed
// Tensor of that TypeShape's dtype and shape, which is then handed to
// MakeConstant(ts_idx, tensor).
//
// Returns whatever MakeConstant returns, or -1 with error_string_ set when:
//   * `ts_idx` fails the bounds check against num_type_shapes_;
//   * the dtype has a fixed element size and `tensor_bytes` does not contain
//     exactly element-size * num_elements bytes;
//   * the TypeShape's metadata marks it as a batch input;
//   * the dtype is not one of the types handled below (e.g. DT_STRING).
tensor_idx_t Weaver::MakeConstantSerialized(
    tensor_idx_t ts_idx, const string &tensor_bytes) {
  if (!FastBoundsCheck(ts_idx, num_type_shapes_)) {
    error_string_ = StrCat("Invalid TypeShape ID: ", ts_idx);
    return -1;
  }
  const auto &type_shape = type_shapes_[ts_idx];

  // Note: we only bother with the size check if dt_size != 0, because when
  // dt_size is zero, that means we're dealing with a DataType without a fixed
  // size (like string). Such types are rejected by the switch below.
  const auto dt_size = tensorflow::DataTypeSize(type_shape.dtype);
  if (dt_size != 0) {
    const auto expected_bytes = dt_size * type_shape.shape.num_elements();
    if (static_cast<size_t>(expected_bytes) != tensor_bytes.size()) {
      error_string_ = StrCat("Invalid serialized tensor passed in; has ",
                             tensor_bytes.size(), " bytes, expected: ",
                             expected_bytes);
      return -1;
    }
  }

  const auto &ts_metadata = metadata_.type_shape_metadata(ts_idx);
  if (ts_metadata.is_batch_input()) {
    error_string_ = StrCat(
        "Cannot create a constant for a TypeShape ", ts_idx,
        " which is in batch mode.");
    return -1;
  }

  Tensor tensor(type_shape.dtype, type_shape.shape);
  switch (type_shape.dtype) {
// Copies the serialized bytes into the tensor's flat buffer.  The copy is
// skipped for empty input so that memcpy is never called with a possibly-null
// destination pointer (undefined behaviour even for a zero byte count).
#define HANDLE_CASE(_tensor_type_)                                      \
  case _tensor_type_:                                                   \
    if (!tensor_bytes.empty()) {                                        \
      memcpy(tensor.flat<EnumToDataType<_tensor_type_>::Type>().data(), \
             tensor_bytes.data(), tensor_bytes.size());                 \
    }                                                                   \
    break
    HANDLE_CASE(DT_FLOAT);
    HANDLE_CASE(DT_DOUBLE);
    HANDLE_CASE(DT_INT32);
    HANDLE_CASE(DT_UINT16);
    HANDLE_CASE(DT_UINT8);
    HANDLE_CASE(DT_INT16);
    HANDLE_CASE(DT_INT8);
    // HANDLE_CASE(DT_STRING);  // String isn't supported.
    HANDLE_CASE(DT_COMPLEX64);
    HANDLE_CASE(DT_COMPLEX128);
    HANDLE_CASE(DT_INT64);
    HANDLE_CASE(DT_BOOL);
    HANDLE_CASE(DT_QINT8);
    HANDLE_CASE(DT_QUINT8);
    HANDLE_CASE(DT_QINT16);
    HANDLE_CASE(DT_QUINT16);
    HANDLE_CASE(DT_QINT32);
    HANDLE_CASE(DT_BFLOAT16);
    HANDLE_CASE(DT_HALF);
#undef HANDLE_CASE
    default:
      // Report unsupported dtypes the same way as every other invalid input
      // instead of aborting the whole process.
      error_string_ =
          StrCat("Weaver.MakeConstantSerialized does not support tensors ",
                 "of type ", DataType_Name(type_shape.dtype));
      return -1;
  }
  return MakeConstant(ts_idx, tensor);
}