def __new__()

in adanet/subnetwork/report.py [0:0]


  def __new__(cls, hparams, attributes, metrics):
    """Validates the report contents and constructs the `Report`.

    Args:
      hparams: Dict mapping hyperparameter names to Python primitive values:
        `bool`, `int`, `float`, or a `six.string_types` string.
      attributes: Dict mapping attribute names to scalar `tf.Tensor`s whose
        base dtype is one of `tf.bool`, `tf.int32`, `tf.float32`,
        `tf.float64`, or `tf.string`.
      metrics: Dict mapping metric names to metric values. Each value is first
        normalized with `tf_compat.metric_op`. The result must be a tuple with
        at least two elements:
        - the first is a `tf.Tensor` or `tf.Variable` with one of the accepted
          dtypes above;
        - the second is a `tf.Tensor`, `tf.Operation`, or `tf.Variable`.

    Returns:
      The instance built by `super(Report, cls).__new__`. `hparams` and
      `attributes` are passed through as given. `metrics` is replaced by a new
      dict holding the normalized tuples. Metrics whose first element is not
      scalar are omitted, with a logged warning.

    Raises:
      ValueError: If any hparam, attribute, or metric fails validation.
    """

    def _is_scalar(tensor):
      """Returns True iff tensor is scalar."""
      # An unknown rank (`ndims` is None) is also treated as non-scalar.
      return tensor.shape.ndims == 0

    def _is_accepted_dtype(tensor):
      """Returns True iff tensor has the dtype we can handle."""
      return tensor.dtype.base_dtype in (tf.bool, tf.int32, tf.float32,
                                         tf.float64, tf.string)

    # Validate hparams
    for key, value in hparams.items():
      if not isinstance(value, (bool, int, float, six.string_types)):
        raise ValueError(
            "hparam '{}' refers to invalid value {}, type {}. type must be "
            "python primitive int, float, bool, or string.".format(
                key, value, type(value)))

    # Validate attributes
    for key, value in attributes.items():
      if not isinstance(value, tf.Tensor):
        # Note the trailing space: the two literals are concatenated.
        raise ValueError("attribute '{}' refers to invalid value: {}, type: {}. "
                         "type must be Tensor.".format(key, value, type(value)))

      if not (_is_scalar(value) and _is_accepted_dtype(value)):
        raise ValueError(
            "attribute '{}' refers to invalid tensor {}. Shape: {}".format(
                key, value, value.get_shape()))

    # Validate metrics. Only metrics that pass validation, in their normalized
    # form, are stored on the report.
    metrics_copy = {}
    for key, value in metrics.items():
      value = tf_compat.metric_op(value)
      if not isinstance(value, tuple):
        raise ValueError(
            "metric '{}' has invalid type {}. Must be a tuple.".format(
                key, type(value)))

      if len(value) < 2:
        raise ValueError(
            "metric tuple '{}' has fewer than 2 elements".format(key))

      if not isinstance(value[0], (tf.Tensor, tf.Variable)):
        raise ValueError(
            "First element of metric tuple '{}' has value {} and type {}. "
            "Must be a Tensor or Variable.".format(key, value[0],
                                                   type(value[0])))

      if not _is_accepted_dtype(value[0]):
        raise ValueError(
            "First element of metric '{}' refers to Tensor of the wrong "
            "dtype {}. Must be one of tf.bool, tf.int32, tf.float32, "
            "tf.float64 or tf.string.".format(key, value[0].dtype))

      if not _is_scalar(value[0]):
        # Non-scalar metrics are dropped, not rejected. They are skipped before
        # the second element is validated, so it is never checked for them.
        tf.logging.warn(
            "First element of metric '{}' refers to Tensor of rank > 0. "
            "AdaNet is currently unable to store metrics of rank > 0 -- this "
            "metric will be dropped from the report. "
            "value: {}".format(key, value[0]))
        continue

      if not isinstance(value[1], (tf.Tensor, tf.Operation, tf.Variable)):
        raise ValueError(
            "Second element of metric tuple '{}' has value {} and type {}. "
            "Must be a Tensor, Operation, or Variable.".format(
                key, value[1], type(value[1])))

      metrics_copy[key] = value

    return super(Report, cls).__new__(
        cls, hparams=hparams, attributes=attributes, metrics=metrics_copy)