def aggregate_losses()

in tf_agents/utils/common.py


def aggregate_losses(per_example_loss=None,
                     sample_weight=None,
                     global_batch_size=None,
                     regularization_loss=None):
  """Aggregates and scales per example loss and regularization losses.

  If `global_batch_size` is given, it is used for scaling; otherwise the
  batch dimension of `per_example_loss` times the number of replicas is used.

  Args:
    per_example_loss: Per-example loss [B] or [B, T, ...].
    sample_weight: Optional weighting for each example, Tensor shaped [B] or
      [B, T, ...], or a scalar float.
    global_batch_size: Optional global batch size value. Defaults to (size of
      first dimension of `per_example_loss`) * (number of replicas).
    regularization_loss: Regularization loss.

  Returns:
    An AggregatedLosses named tuple with scalar losses to optimize.
  """
  total_loss, weighted_loss, reg_loss = None, None, None
  if sample_weight is not None and not isinstance(sample_weight, tf.Tensor):
    sample_weight = tf.convert_to_tensor(sample_weight, dtype=tf.float32)

  # Compute loss that is scaled by global batch size.
  if per_example_loss is not None:
    loss_rank = per_example_loss.shape.rank
    if sample_weight is not None:
      weight_rank = sample_weight.shape.rank
      # Expand `sample_weight` to be broadcastable to the shape of
      # `per_example_loss`, to ensure that multiplication works properly.
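      # For example, a weight of shape [B] against a loss of shape [B, T] is
      # expanded to [B, 1] so it broadcasts across the time dimension.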
      if weight_rank > 0 and loss_rank > weight_rank:
        for dim in range(weight_rank, loss_rank):
          sample_weight = tf.expand_dims(sample_weight, dim)
      # Sometimes, e.g. at an episode boundary, the loss is nonsensical
      # (inf or nan) and sample_weight is zero. In that case we should
      # respect the zero sample_weight and ignore the frame.
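      # Note: tf.math.multiply_no_nan returns 0 wherever sample_weight is 0,
      # even if per_example_loss is inf or nan at that position, whereas a
      # plain multiply would propagate the nan.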
      per_example_loss = tf.math.multiply_no_nan(
          per_example_loss, sample_weight)

    if loss_rank is not None and loss_rank == 0:
      err_msg = (
          'Need to use a loss function that computes losses per sample, e.g. '
          'replace losses.mean_squared_error with tf.math.squared_difference. '
          'Invalid value passed for `per_example_loss`. Expected a tensor '
          'with at least rank 1, received: {}'.format(per_example_loss))
      if tf.distribute.has_strategy():
        raise ValueError(err_msg)
      else:
        logging.warning(err_msg)
        # Add extra dimension to prevent error in compute_average_loss.
        per_example_loss = tf.expand_dims(per_example_loss, 0)
    elif loss_rank is not None and loss_rank > 1:
      # If per_example_loss is shaped [B, T, ...], also take the mean across
      # the extra dimensions, e.g. time, so each example yields one scalar.
      per_example_loss = tf.reduce_mean(
          per_example_loss, axis=list(range(1, loss_rank)))

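    # Cast `global_batch_size` only when provided; the `and` keeps None as
    # None. tf.nn.compute_average_loss then sums the per-example losses and
    # divides by `global_batch_size`, defaulting to the local batch size
    # times the number of replicas in sync when it is None.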
    global_batch_size = global_batch_size and tf.cast(global_batch_size,
                                                      per_example_loss.dtype)
    weighted_loss = tf.nn.compute_average_loss(
        per_example_loss,
        global_batch_size=global_batch_size)
    total_loss = weighted_loss
  # Add scaled regularization losses.
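  # tf.nn.scale_regularization_loss divides the summed regularization loss by
  # the number of replicas in sync, so adding it on every replica matches a
  # single-replica run once gradients are aggregated.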
  if regularization_loss is not None:
    reg_loss = tf.nn.scale_regularization_loss(regularization_loss)
    if total_loss is None:
      total_loss = reg_loss
    else:
      total_loss += reg_loss
  return AggregatedLosses(total_loss, weighted_loss, reg_loss)
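
A minimal usage sketch, assuming the function is imported from
`tf_agents.utils.common`; the shapes, weights, and regularization value below
are illustrative assumptions, not taken from the library:

import tensorflow as tf
from tf_agents.utils import common

# Hypothetical batch of 4 examples with 3 timesteps each; the last example
# is masked out by a zero sample weight.
per_example_loss = tf.random.uniform([4, 3])
sample_weight = tf.constant([1., 1., 1., 0.])

losses = common.aggregate_losses(
    per_example_loss=per_example_loss,
    sample_weight=sample_weight,
    regularization_loss=tf.constant(0.01))

# The result unpacks positionally, matching the return statement above; each
# element is a scalar tensor.
total_loss, weighted_loss, reg_loss = losses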