in tf_agents/utils/common.py [0:0]
def aggregate_losses(per_example_loss=None,
                     sample_weight=None,
                     global_batch_size=None,
                     regularization_loss=None):
"""Aggregates and scales per example loss and regularization losses.
If `global_batch_size` is given it would be used for scaling, otherwise it
would use the batch_dim of per_example_loss and number of replicas.
Args:
per_example_loss: Per-example loss [B] or [B, T, ...].
sample_weight: Optional weighting for each example, Tensor shaped [B] or
[B, T, ...], or a scalar float.
global_batch_size: Optional global batch size value. Defaults to (size of
first dimension of `losses`) * (number of replicas).
regularization_loss: Regularization loss.
Returns:
An AggregatedLosses named tuple with scalar losses to optimize.
"""
  total_loss, weighted_loss, reg_loss = None, None, None
  if sample_weight is not None and not isinstance(sample_weight, tf.Tensor):
    sample_weight = tf.convert_to_tensor(sample_weight, dtype=tf.float32)

  # Compute loss that is scaled by global batch size.
  if per_example_loss is not None:
    loss_rank = per_example_loss.shape.rank
    if sample_weight is not None:
      weight_rank = sample_weight.shape.rank
      # Expand `sample_weight` to be broadcastable to the shape of
      # `per_example_loss`, to ensure that multiplication works properly.
      if weight_rank > 0 and loss_rank > weight_rank:
        for dim in range(weight_rank, loss_rank):
          sample_weight = tf.expand_dims(sample_weight, dim)
      # Sometimes we hit an episode boundary or similar, where the loss is
      # nonsensical (i.e., inf or NaN) while sample_weight is zero. In that
      # case, respect the zero sample_weight and ignore the frame.
      per_example_loss = tf.math.multiply_no_nan(
          per_example_loss, sample_weight)
    if loss_rank is not None and loss_rank == 0:
      err_msg = (
          'Need to use a loss function that computes losses per sample, e.g. '
          'replace losses.mean_squared_error with tf.math.squared_difference. '
          'Invalid value passed for `per_example_loss`. Expected a tensor '
          'with at least rank 1, received: {}'.format(per_example_loss))
      if tf.distribute.has_strategy():
        raise ValueError(err_msg)
      else:
        logging.warning(err_msg)
        # Add an extra dimension to prevent an error in compute_average_loss.
        per_example_loss = tf.expand_dims(per_example_loss, 0)
    elif loss_rank > 1:
      # If per_example_loss is shaped [B, T, ...], compute the mean across the
      # extra dimensions (e.g. time) as well.
      per_example_loss = tf.reduce_mean(per_example_loss, range(1, loss_rank))

    global_batch_size = global_batch_size and tf.cast(global_batch_size,
                                                      per_example_loss.dtype)
    weighted_loss = tf.nn.compute_average_loss(
        per_example_loss,
        global_batch_size=global_batch_size)
    total_loss = weighted_loss
  # Add scaled regularization losses.
  if regularization_loss is not None:
    reg_loss = tf.nn.scale_regularization_loss(regularization_loss)
    if total_loss is None:
      total_loss = reg_loss
    else:
      total_loss += reg_loss
  return AggregatedLosses(total_loss, weighted_loss, reg_loss)
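
Usage sketch (illustrative, not part of the file): a rank-2 per-example loss
[B, T] is masked by a zero weight at an episode boundary, averaged over time
and batch, and combined with an L2 regularization term. The `weights`
variable and the expected values are assumptions for the example; the field
name `total_loss` follows the AggregatedLosses namedtuple defined elsewhere
in this module.

import tensorflow as tf
from tf_agents.utils import common

# [B=2, T=2] loss; the NaN frame sits at an episode boundary.
per_example_loss = tf.constant([[1.0, 2.0], [3.0, float('nan')]])
# Zero weight at the boundary frame; multiply_no_nan masks the NaN.
sample_weight = tf.constant([[1.0, 1.0], [1.0, 0.0]])
weights = tf.Variable([0.5, -0.5])  # hypothetical model variable
regularization_loss = tf.nn.l2_loss(weights)  # sum(w**2) / 2 = 0.25

losses = common.aggregate_losses(
    per_example_loss=per_example_loss,
    sample_weight=sample_weight,
    regularization_loss=regularization_loss)
print(losses.total_loss)  # mean([1.5, 1.5]) + 0.25 = 1.75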
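
Under a tf.distribute strategy, tf.nn.compute_average_loss divides by the
global batch size rather than the per-replica batch, so summing per-replica
results recovers the true mean over the global batch. A minimal sketch,
assuming TF 2.x and whatever local devices MirroredStrategy picks up:

import tensorflow as tf
from tf_agents.utils import common

strategy = tf.distribute.MirroredStrategy()  # falls back to one CPU replica

def replica_fn():
  # Per-replica batch of 2; the global batch is 2 * num_replicas.
  per_example_loss = tf.constant([1.0, 3.0])
  return common.aggregate_losses(
      per_example_loss=per_example_loss,
      global_batch_size=2 * strategy.num_replicas_in_sync).total_loss

per_replica = strategy.run(replica_fn)
# Summing across replicas yields mean([1.0, 3.0]) = 2.0.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)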