in tf_agents/utils/common.py [0:0]
import tensorflow as tf  # Required by the `tf.*` calls below; imported at the top of common.py.


def soft_variables_update(source_variables,
                          target_variables,
                          tau=1.0,
                          tau_non_trainable=None,
                          sort_variables_by_name=False):
"""Performs a soft/hard update of variables from the source to the target.
Note: **when using this function with TF DistributionStrategy**, the
`strategy.extended.update` call (below) needs to be done in a cross-replica
context, i.e. inside a merge_call. Please use the Periodically class above
that provides this wrapper for you.
For each variable v_t in target variables and its corresponding variable v_s
in source variables, a soft update is:
v_t = (1 - tau) * v_t + tau * v_s
When tau is 1.0 (the default), then it does a hard update:
v_t = v_s
Args:
source_variables: list of source variables.
target_variables: list of target variables.
tau: A float scalar in [0, 1]. When tau is 1.0 (the default), we do a hard
update. This is used for trainable variables.
tau_non_trainable: A float scalar in [0, 1] for non_trainable variables. If
None, will copy from tau.
sort_variables_by_name: A bool, when True would sort the variables by name
before doing the update.
Returns:
An operation that updates target variables from source variables.
Raises:
ValueError: if `tau not in [0, 1]`.
ValueError: if `len(source_variables) != len(target_variables)`.
ValueError: "Method requires being in cross-replica context,
use get_replica_context().merge_call()" if used inside replica context.
"""
  if tau < 0 or tau > 1:
    raise ValueError('Input `tau` should be in [0, 1].')
  if tau_non_trainable is None:
    tau_non_trainable = tau

  if tau_non_trainable < 0 or tau_non_trainable > 1:
    raise ValueError('Input `tau_non_trainable` should be in [0, 1].')

  updates = []

  op_name = 'soft_variables_update'
  if tau == 0.0 or not source_variables or not target_variables:
    return tf.no_op(name=op_name)
  if len(source_variables) != len(target_variables):
    raise ValueError(
        'Source and target variable lists have different lengths: '
        '{} vs. {}'.format(len(source_variables), len(target_variables)))
  if sort_variables_by_name:
    source_variables = sorted(source_variables, key=lambda x: x.name)
    target_variables = sorted(target_variables, key=lambda x: x.name)

  strategy = tf.distribute.get_strategy()

  for (v_s, v_t) in zip(source_variables, target_variables):
    v_t.shape.assert_is_compatible_with(v_s.shape)
    def update_fn(v1, v2):
      """Update variables."""
      # Non-trainable variables use `tau_non_trainable`; hard-updating them
      # helps stabilize BatchNorm moving averages. TODO(b/144455039)
      if not v1.trainable:
        current_tau = tau_non_trainable
      else:
        current_tau = tau

      if current_tau == 1.0:
        return v1.assign(v2)
      else:
        return v1.assign((1 - current_tau) * v1 + current_tau * v2)
    # TODO(b/142508640): remove this when b/142802462 is fixed.
    # Workaround for b/142508640: only use extended.update for MirroredVariable
    # variables (which are trainable variables). For other variable types
    # (e.g. SyncOnReadVariables such as batch norm stats), do a regular assign,
    # which causes a sync and broadcast from replica 0. This is slower but
    # correct and does not fail.
    if tf.distribute.has_strategy() and v_t.trainable:
      # Assignment happens independently on each replica,
      # see b/140690837 #46.
      update = strategy.extended.update(v_t, update_fn, args=(v_s,))
    else:
      update = update_fn(v_t, v_s)

    updates.append(update)
  return tf.group(*updates, name=op_name)
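
A minimal usage sketch (illustrative only, not part of common.py): the function is typically used to move a target network's variables toward an online network's variables, e.g. Polyak averaging of a target Q-network. The layer names below are hypothetical; `tf` and `soft_variables_update` are assumed to be in scope.

online_layer = tf.keras.layers.Dense(4)
target_layer = tf.keras.layers.Dense(4)
online_layer.build(input_shape=(None, 8))
target_layer.build(input_shape=(None, 8))

# Soft update: each target variable moves 0.5% of the way toward its source.
soft_variables_update(
    online_layer.variables, target_layer.variables, tau=0.005)

# Hard update (tau=1.0): target variables are overwritten with source values.
soft_variables_update(online_layer.variables, target_layer.variables, tau=1.0)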