static Status ValuesFromConstNode()

in ngraph_bridge/ngraph_builder.cc [365:472]


// Extracts the elements of a "Const" NodeDef into a flat vector.
//
// Template parameters (declared on the enclosing template header):
//   T    - element type that must match the node's "dtype" attribute
//          (compared against DataTypeToEnum<T>::value).
//   VecT - element type of the output vector; values of type T are converted
//          to VecT on assignment.
//
// Parameters:
//   node               - NodeDef to read; must have op() == "Const" and carry
//                        "dtype" and "value" attributes.
//   const_tensor_shape - out: receives a copy of the constant's
//                        TensorShapeProto.
//   values             - out: on success, its previous contents are replaced
//                        by the constant's elements.
//
// The TensorProto may hold its data either in tensor_content (raw bytes) or
// in the typed <type>_val repeated fields. In the latter case the proto may
// store fewer values than the shape implies; the last stored value is then
// repeated for the remaining elements.
//
// Returns:
//   Status::OK() on success.
//   InvalidArgument if the node is not a Const, lacks "dtype"/"value", has a
//     dtype that differs from T, has a tensor_content whose byte size is not a
//     multiple of sizeof(VecT), has a negative (unknown) dimension while its
//     data is in <type>_val fields, or stores no values for a non-empty shape.
//   Unimplemented if a non-empty tensor keeps its data in <type>_val fields for
//     a dtype that is not decoded below.
static Status ValuesFromConstNode(const NodeDef& node,
                                  TensorShapeProto* const_tensor_shape,
                                  std::vector<VecT>* values) {
  if (node.op() != "Const") {
    return errors::InvalidArgument("Node not a Const");
  }

  // Use find() rather than at(): protobuf's Map::at() fails hard on a missing
  // key, whereas malformed graphs should be reported through a Status.
  const auto& attrs = node.attr();
  const auto dtype_it = attrs.find("dtype");
  const auto value_it = attrs.find("value");
  if (dtype_it == attrs.end() || value_it == attrs.end()) {
    return errors::InvalidArgument(
        "Const node is missing its 'dtype' or 'value' attribute");
  }

  const auto dt = dtype_it->second.type();
  if (dt != DataTypeToEnum<T>::value) {
    std::stringstream ss;
    ss << "Invalid data type defined for Const. Defined: " << dt;
    return errors::InvalidArgument(ss.str());
  }

  // TensorProto represents the content of the tensor in either <type>_val or
  // tensor_content.
  const TensorProto& tensor = value_it->second.tensor();
  // MutableTensorProtoData is called with a non-const pointer, but the field
  // is only read in this function; the proto is never modified here.
  typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
      checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));

  const TensorShapeProto& shape = tensor.tensor_shape();
  *const_tensor_shape = shape;
  if (!tensor_values->empty() && tensor.has_tensor_shape()) {
    // When tensor_shape is set, the representation of the data could in
    // theory be compressed (fewer stored values than elements). Only take
    // this direct-copy path when every element is explicitly stored.
    if (shape.dim_size() == 1 && shape.dim(0).size() == tensor_values->size()) {
      // assign (not insert) so that this path, like the others below,
      // replaces rather than appends to the caller's vector.
      values->assign(tensor_values->begin(), tensor_values->end());
      return Status::OK();
    }
  }

  const auto tensor_content_size = tensor.tensor_content().size();
  // Report malformed input as an error instead of aborting via CHECK.
  if (tensor_content_size % sizeof(VecT) != 0) {
    return errors::InvalidArgument("tensor_content_size (", tensor_content_size,
                                   ") is not a multiple of ", sizeof(VecT));
  }

  if (tensor_content_size == 0) {
    // Values are stored in the typed <type>_val fields; derive the element
    // count from the shape, which requires every dimension to be known.
    int64 n_elements = 1;
    for (int d = 0; d < shape.dim_size(); ++d) {
      const int64 dim_size = shape.dim(d).size();
      if (dim_size < 0) {
        return errors::InvalidArgument(
            "Const node has empty tensor and an unknown dimension size");
      }
      n_elements *= dim_size;
    }
    values->resize(n_elements);

    T last_saved = static_cast<T>(0);
    // int64 index: n_elements is int64 and must not be compared against (or
    // overflow) a plain int counter.
    for (int64 i = 0; i < n_elements; ++i) {
      int64 val_size = 0;
      T val_i = static_cast<T>(0);
      // Element i is read only when it is actually stored (i < val_size);
      // indexing a repeated field at or beyond its size is out of bounds.
      switch (dt) {
        // TODO(amprocte/NGRAPH-2502): there are more element types to support
        // here
        case DT_INT32:
          val_size = tensor.int_val_size();
          if (i < val_size) val_i = tensor.int_val(static_cast<int>(i));
          break;
        case DT_INT64:
          val_size = tensor.int64_val_size();
          if (i < val_size) val_i = tensor.int64_val(static_cast<int>(i));
          break;
        case DT_FLOAT:
          val_size = tensor.float_val_size();
          if (i < val_size) val_i = tensor.float_val(static_cast<int>(i));
          break;
        case DT_BOOL:
          val_size = tensor.bool_val_size();
          if (i < val_size) val_i = tensor.bool_val(static_cast<int>(i));
          break;
        case DT_DOUBLE:
          val_size = tensor.double_val_size();
          if (i < val_size) val_i = tensor.double_val(static_cast<int>(i));
          break;
        default:
          NGRAPH_VLOG(0)
              << "Const node has empty tensor and we don't know how to "
                 "handle this element type";
          NGRAPH_VLOG(0) << node.DebugString();
          NGRAPH_VLOG(0) << shape.DebugString();
          return errors::Unimplemented("Encountered unknown element type ",
                                       DataType_Name(dt),
                                       " on an empty tensor");
      }
      if (val_size == 0) {
        return errors::InvalidArgument("Empty values vector");
      } else if (i < val_size) {
        (*values)[i] = val_i;
        last_saved = val_i;
      } else {
        // Fewer values stored than elements: repeat the last stored value.
        (*values)[i] = last_saved;
      }
    }
  } else {
    // Raw little-endian bytes; the size was validated above to be a whole
    // number of VecT elements.
    values->resize(tensor_content_size / sizeof(VecT));
    port::CopyToArray(tensor.tensor_content(),
                      reinterpret_cast<char*>(values->data()));
  }

  return Status::OK();
}