def soft_variables_update()

in tf_agents/utils/common.py


def soft_variables_update(source_variables,
                          target_variables,
                          tau=1.0,
                          tau_non_trainable=None,
                          sort_variables_by_name=False):
  """Performs a soft/hard update of variables from the source to the target.

  Note: **when using this function with TF DistributionStrategy**, the
  `strategy.extended.update` call (below) needs to be done in a cross-replica
  context, i.e. inside a merge_call. Please use the Periodically class above,
  which provides this wrapper for you.

  For each variable v_t in target variables and its corresponding variable v_s
  in source variables, a soft update is:
  v_t = (1 - tau) * v_t + tau * v_s

  When tau is 1.0 (the default), then it does a hard update:
  v_t = v_s

  Args:
    source_variables: list of source variables.
    target_variables: list of target variables.
    tau: A float scalar in [0, 1]. When tau is 1.0 (the default), we do a hard
      update. This is used for trainable variables.
    tau_non_trainable: A float scalar in [0, 1] for non-trainable variables.
      If None, defaults to `tau`.
    sort_variables_by_name: A bool; when True, the variables are sorted by
      name before doing the update.

  Returns:
    An operation that updates target variables from source variables.

  Raises:
    ValueError: if `tau not in [0, 1]`.
    ValueError: if `len(source_variables) != len(target_variables)`.
    ValueError: "Method requires being in cross-replica context,
      use get_replica_context().merge_call()" if used inside replica context.
  """
  if tau < 0 or tau > 1:
    raise ValueError('Input `tau` should be in [0, 1].')
  if tau_non_trainable is None:
    tau_non_trainable = tau

  if tau_non_trainable < 0 or tau_non_trainable > 1:
    raise ValueError('Input `tau_non_trainable` should be in [0, 1].')

  updates = []

  op_name = 'soft_variables_update'
  if tau == 0.0 or not source_variables or not target_variables:
    return tf.no_op(name=op_name)
  if len(source_variables) != len(target_variables):
    raise ValueError(
        'Source and target variable lists have different lengths: '
        '{} vs. {}'.format(len(source_variables), len(target_variables)))
  if sort_variables_by_name:
    source_variables = sorted(source_variables, key=lambda x: x.name)
    target_variables = sorted(target_variables, key=lambda x: x.name)

  strategy = tf.distribute.get_strategy()

  for (v_s, v_t) in zip(source_variables, target_variables):
    v_t.shape.assert_is_compatible_with(v_s.shape)

    def update_fn(v1, v2):
      """Update variables."""
      # For non-trainable variables, do hard updates.
      # This helps stabilize BatchNorm moving averages. TODO(b/144455039)
      if not v1.trainable:
        current_tau = tau_non_trainable
      else:
        current_tau = tau

      if current_tau == 1.0:
        return v1.assign(v2)
      else:
        return v1.assign((1 - current_tau) * v1 + current_tau * v2)

    # TODO(b/142508640): remove this when b/142802462 is fixed.
    # Workaround for b/142508640: only use extended.update for
    # MirroredVariable variables (which are trainable variables).
    # For other types of variables (e.g. SyncOnReadVariables such as
    # batch norm stats), do a regular assign, which causes a sync and
    # broadcast from replica 0; this is slower but correct and does not
    # cause a failure.
    if tf.distribute.has_strategy() and v_t.trainable:
      # Assignment happens independently on each replica,
      # see b/140690837 #46.
      update = strategy.extended.update(v_t, update_fn, args=(v_s,))
    else:
      update = update_fn(v_t, v_s)

    updates.append(update)
  return tf.group(*updates, name=op_name)
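
The following is a minimal usage sketch, not part of the source file: it keeps a
target network softly tracking an online network, the pattern used by DQN/DDPG-style
agents. The layer sizes, the `tau=0.005` value, and the `period=100` interval are
illustrative assumptions; `Periodically` is the wrapper class the docstring refers to.

import tensorflow as tf

from tf_agents.utils import common

# Illustrative online/target networks; any two networks whose variable lists
# match in shape and order will do.
online_net = tf.keras.layers.Dense(4)
target_net = tf.keras.layers.Dense(4)

# Call both layers once so their variables are created.
dummy_input = tf.zeros([1, 8])
online_net(dummy_input)
target_net(dummy_input)

# Hard update (tau=1.0): copy the online weights into the target once.
common.soft_variables_update(
    online_net.variables, target_net.variables, tau=1.0)

# Soft update after a training step: v_t = (1 - tau) * v_t + tau * v_s.
common.soft_variables_update(
    online_net.variables, target_net.variables, tau=0.005)

# Optionally wrap the update in Periodically so it only fires every `period`
# calls; per the docstring, Periodically also provides the cross-replica
# context needed when running under a DistributionStrategy.
update_target = common.Periodically(
    lambda: common.soft_variables_update(
        online_net.variables, target_net.variables, tau=0.005),
    period=100)

for _ in range(300):
  # ... one gradient step on online_net would go here ...
  update_target()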