in tensorflow_estimator/python/estimator/head/base_head.py [0:0]
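The function below relies on names defined at the top of base_head.py, which this excerpt does not include; the following is a sketch of the assumed module-level imports (the feature-column aliases in particular are assumptions, not copied from the file):

# Assumed module-level dependencies (base_head.py's real import block is not
# part of this excerpt):
import six
import tensorflow as tf
from tensorflow.python.framework import ops
# feature_column_lib, _NumericColumn and _LazyBuilder are module-level aliases
# for TensorFlow's public and private (v1) feature-column types; their exact
# import lines are not reproduced here.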
def get_weights_and_check_match_logits(features,
                                       weight_column,
                                       logits,
                                       allow_per_logit_weights=False):
  """Fetches weights from features and checks that the shape matches logits.

  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape
  can be either:
  * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`.
  * [D0, D1, ... DN, 1]
  * [D0, D1, ... DN]: In this case, weights is reshaped into
    [D0, D1, ... DN, 1] to work with weight broadcasting rules.

  Args:
    features: The features dict that contains weights.
    weight_column: The weight column. If not given, this method returns 1.
    logits: logits `Tensor`.
    allow_per_logit_weights: Boolean. Whether we allow weights along the logits
      dimension, namely shape `[D0, D1, ... DN, logits_dimension]`.

  Returns:
    Validated and reshaped weights `Tensor`.

  Raises:
    TypeError: If `weight_column` is neither a string nor a `NumericColumn`.
    ValueError: If the weights `Tensor` cannot be cast into float.
  """
  if allow_per_logit_weights:
    err_msg = ('weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
               '[D0, D1, ... DN, logits_dimension]')
  else:
    err_msg = ('weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]')
  with ops.name_scope(
      'weights', values=tuple(six.itervalues(features)) + (logits,)) as scope:
    # Fetch the weights.
    if weight_column is None:
      return 1.
    # TODO(b/117839674): update feature_column
    if isinstance(weight_column, six.string_types):
      weight_column = tf.feature_column.numeric_column(
          key=weight_column, shape=(1,))
    if not isinstance(weight_column,
                      (feature_column_lib.NumericColumn, _NumericColumn)):
      raise TypeError('Weight column must be either a string or NumericColumn.'
                      ' Given type: {}.'.format(type(weight_column)))
    weights = weight_column._get_dense_tensor(  # pylint: disable=protected-access
        _LazyBuilder(features))
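    # The dense weights tensor must have a numeric (float or integer) dtype so
    # it can be cast to float32 below.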
    if not (weights.dtype.is_floating or weights.dtype.is_integer):
      raise ValueError('Weight column should be castable to float. '
                       'Given dtype: {}'.format(weights.dtype))
    weights = tf.cast(weights, name='weights', dtype=tf.dtypes.float32)
    # Validate the weights shape.
    # Eager mode.
    if tf.executing_eagerly():
      weights_shape = weights._shape_tuple()  # pylint: disable=protected-access
      logits_shape = logits._shape_tuple()  # pylint: disable=protected-access
      weights_rank = weights._rank()  # pylint: disable=protected-access
      logits_rank = logits._rank()  # pylint: disable=protected-access
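      # Weights with one fewer dimension than logits (shape [D0, D1, ... DN])
      # must match logits' leading dimensions; a trailing size-1 dimension is
      # then added so the weights broadcast against the logits.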
      if (weights_rank is not None and logits_rank is not None and
          weights_rank == logits_rank - 1):
        if logits_shape[:-1] != weights_shape:
          raise ValueError('{}, logits_shape: {}. weights_shape: {}.'.format(
              err_msg, logits_shape, weights_shape))
        return tf.compat.v1.expand_dims(weights, -1, name=scope)
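      # Otherwise the weights are expected to carry an explicit trailing
      # dimension: size 1, or logits_dimension when allow_per_logit_weights
      # is True.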
      supported_weights_shape = logits_shape[:-1] + (1,)
      if allow_per_logit_weights:
        if (logits_shape != weights_shape and
            supported_weights_shape != weights_shape):
          raise ValueError('{}, logits_shape: {}. weights_shape: {}.'.format(
              err_msg, logits_shape, weights_shape))
      else:
        if supported_weights_shape != weights_shape:
          raise ValueError('{}, logits_shape: {}. weights_shape: {}.'.format(
              err_msg, logits_shape, weights_shape))
      return weights
    # Graph mode.
    weights_shape = tf.compat.v1.shape(weights, name='weights_shape')
    logits_shape = tf.compat.v1.shape(logits, name='logits_shape')
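    # If static ranks show weights have one fewer dimension than logits,
    # assert at run time that the leading dimensions match, then add a
    # trailing size-1 dimension so the weights broadcast against the logits.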
    if (weights.shape.ndims is not None and logits.shape.ndims is not None and
        weights.shape.ndims == logits.shape.ndims - 1):
      assert_dimension = tf.compat.v1.debugging.assert_equal(
          logits_shape[:-1],
          weights_shape,
          message=err_msg,
          data=[
              'logits_shape: ', logits_shape, 'weights_shape: ', weights_shape
          ])
      with tf.control_dependencies([assert_dimension]):
        return tf.compat.v1.expand_dims(weights, -1, name=scope)
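    # Expected per-example weights shape, [D0, D1, ... DN, 1], built at run
    # time from the logits shape.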
    supported_weights_shape = tf.concat([logits_shape[:-1], [1]], axis=0)
    if allow_per_logit_weights:
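      # Per-logit weights are allowed: accept either the full logits shape
      # ([D0, D1, ... DN, logits_dimension]) or the per-example shape
      # ([D0, D1, ... DN, 1]).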
      condition = tf.math.reduce_any([
          tf.reduce_all(tf.math.equal(logits_shape, weights_shape)),
          tf.reduce_all(tf.math.equal(supported_weights_shape, weights_shape))
      ])
      assert_dimension = tf.debugging.Assert(
          condition=condition,
          data=[
              err_msg, 'logits_shape: ', logits_shape, 'weights_shape: ',
              weights_shape
          ])
    else:
      assert_dimension = tf.compat.v1.debugging.assert_equal(
          supported_weights_shape,
          weights_shape,
          message=err_msg,
          data=[
              'logits_shape: ', logits_shape, 'weights_shape: ', weights_shape
          ])
    with tf.control_dependencies([assert_dimension]):
      return tf.identity(weights, name=scope)
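
For reference, a minimal sketch of the shape contract this function enforces, written against public TF ops only (the shapes and values below are illustrative and do not come from base_head.py; this does not call the private helper itself):

# Per-example weights of shape [D0] are expanded to [D0, 1] so that they
# broadcast against logits of shape [D0, logits_dimension].
import tensorflow as tf

logits = tf.zeros([4, 3])                    # [D0, logits_dimension] = [4, 3]
weights = tf.constant([1.0, 2.0, 0.5, 1.0])  # [D0] -- one weight per example

expanded = tf.expand_dims(weights, -1)       # shape [4, 1]
weighted = logits * expanded                 # broadcasts to [4, 3]

Per-logit weights of shape [4, 3] would instead be accepted unchanged when allow_per_logit_weights=True.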