in adanet/subnetwork/report.py [0:0]
def __new__(cls, hparams, attributes, metrics):
    """Validates the given fields and creates the report tuple.

    Args:
      hparams: Dict mapping hparam names to Python primitive values. Every
        value must be a `bool`, `int`, `float`, or string.
      attributes: Dict mapping attribute names to `tf.Tensor` objects. Every
        tensor must be scalar (rank 0) and have one of the dtypes `tf.bool`,
        `tf.int32`, `tf.float32`, `tf.float64` or `tf.string`.
      metrics: Dict mapping metric names to metric values. Each value is
        first passed through `tf_compat.metric_op` and the result must be a
        tuple of at least two elements: the first a `tf.Tensor` or
        `tf.Variable` of an accepted dtype, the second a `tf.Tensor`,
        `tf.Operation` or `tf.Variable`. Metrics whose first element is not
        scalar are logged with a warning and left out of the report.

    Returns:
      A new instance built by the parent class's `__new__` from `hparams`,
      `attributes` and the subset of `metrics` that passed validation.

    Raises:
      ValueError: If an hparam is not a supported Python primitive, if an
        attribute is not a scalar `tf.Tensor` of an accepted dtype, or if a
        metric is not a well-formed tuple as described above.
    """

    def _is_scalar(tensor):
        """Returns True iff tensor is scalar."""
        return tensor.shape.ndims == 0

    def _is_accepted_dtype(tensor):
        """Returns True iff tensor has the dtype we can handle."""
        return tensor.dtype.base_dtype in (tf.bool, tf.int32, tf.float32,
                                           tf.float64, tf.string)

    # Validate hparams
    for key, value in hparams.items():
        if not isinstance(value, (bool, int, float, six.string_types)):
            raise ValueError(
                "hparam '{}' refers to invalid value {}, type {}. type must "
                "be python primitive int, float, bool, or string.".format(
                    key, value, type(value)))

    # Validate attributes
    for key, value in attributes.items():
        if not isinstance(value, tf.Tensor):
            # The two literals are joined by implicit concatenation, so the
            # first one must end with a space to keep the sentences apart.
            raise ValueError(
                "attribute '{}' refers to invalid value: {}, type: {}. "
                "type must be Tensor.".format(key, value, type(value)))
        if not (_is_scalar(value) and _is_accepted_dtype(value)):
            # Either the shape or the dtype may be the culprit, so report
            # both instead of the shape alone.
            raise ValueError(
                "attribute '{}' refers to invalid tensor {}. Shape: {}, "
                "dtype: {}".format(key, value, value.get_shape(),
                                   value.dtype))

    # Validate metrics. Only accepted entries are copied into
    # `metrics_copy`; the caller's dict is never modified.
    metrics_copy = {}
    for key, value in metrics.items():
        value = tf_compat.metric_op(value)
        if not isinstance(value, tuple):
            raise ValueError(
                "metric '{}' has invalid type {}. Must be a tuple.".format(
                    key, type(value)))
        if len(value) < 2:
            raise ValueError(
                "metric tuple '{}' has fewer than 2 elements".format(key))
        if not isinstance(value[0], (tf.Tensor, tf.Variable)):
            raise ValueError(
                "First element of metric tuple '{}' has value {} and type "
                "{}. Must be a Tensor or Variable.".format(
                    key, value[0], type(value[0])))
        if not _is_accepted_dtype(value[0]):
            raise ValueError(
                "First element of metric '{}' refers to Tensor of the wrong "
                "dtype {}. Must be one of tf.bool, tf.int32, tf.float32, "
                "tf.float64 or tf.string.".format(key, value[0].dtype))
        if not _is_scalar(value[0]):
            # Non-scalar metrics are dropped with a warning rather than
            # rejected, so the remaining metrics are still reported.
            # NOTE(review): `tf.logging` is a TF 1.x API -- confirm it is
            # available under the TensorFlow versions this file targets.
            tf.logging.warn(
                "First element of metric '{}' refers to Tensor of rank > 0. "
                "AdaNet is currently unable to store metrics of rank > 0 -- "
                "this metric will be dropped from the report. "
                "value: {}".format(key, value[0]))
            continue
        if not isinstance(value[1],
                          (tf.Tensor, tf.Operation, tf.Variable)):
            raise ValueError(
                "Second element of metric tuple '{}' has value {} and type "
                "{}. Must be a Tensor, Operation, or Variable.".format(
                    key, value[1], type(value[1])))
        metrics_copy[key] = value

    return super(Report, cls).__new__(
        cls, hparams=hparams, attributes=attributes, metrics=metrics_copy)