in tf_agents/policies/policy_saver.py
def __init__(
self,
policy: tf_policy.TFPolicy,
batch_size: Optional[int] = None,
use_nest_path_signatures: bool = True,
seed: Optional[types.Seed] = None,
train_step: Optional[tf.Variable] = None,
input_fn_and_spec: Optional[InputFnAndSpecType] = None,
metadata: Optional[Dict[Text, tf.Variable]] = None
):
"""Initialize PolicySaver for TF policy `policy`.
Args:
policy: A TF Policy.
batch_size: The number of batch entries the policy will process at a time.
This must be either `None` (unknown batch size) or a python integer.
use_nest_path_signatures: If `True`, SavedModel spec signatures are
created based on the structure (nest paths) of the specs. Otherwise all
specs must have unique names.
seed: Random seed for the `policy.action` call, if any (this should
usually be `None`, except for testing).
train_step: Variable holding the train step for the policy. The value
saved is whatever the variable holds at the time `saver.save` is called.
If not provided, `train_step` defaults to -1. Note: because the train
step must be a variable, it is not safe to create one here in TF1, so in
that case this is a required parameter.
input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where `input_fn` is a
function that takes inputs conforming to `tensor_spec` and converts them
to the `(time_step, policy_state)` tuple that is used as the input to the
action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input
for the action signature; when `input_fn_and_spec` is `None`, the action
signature takes `(time_step, policy_state)` as input. See the sketch at
the end of this docstring.
metadata: A dictionary of `tf.Variables` to be saved along with the
policy.
Raises:
TypeError: If `policy` is not an instance of TFPolicy.
TypeError: If `metadata` is not a dictionary with string keys and
`tf.Variable` values.
ValueError: If `use_nest_path_signatures` is `False` and any of the
following `policy` specs are missing names, or the names collide:
`policy.time_step_spec`, `policy.action_spec`,
`policy.policy_state_spec`, `policy.info_spec`.
ValueError: If `batch_size` is neither `None` nor a python integer > 0.
ValueError: If `train_step` is not a `tf.Variable`, or if it is `None`
when eager execution is not enabled (TF1).
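A minimal sketch of an `input_fn_and_spec` pair, assuming (for
illustration only) that the policy consumes flat float observations of
length 4:

```python
input_spec = tf.TensorSpec(shape=(4,), dtype=tf.float32, name='observation')

def input_fn(observation):
  # Wrap the raw observation in a restarting TimeStep and use an empty
  # policy state.
  time_step = ts.restart(observation, batch_size=tf.shape(observation)[0])
  return time_step, ()

saver = policy_saver.PolicySaver(
    policy, input_fn_and_spec=(input_fn, input_spec))
```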
"""
if not isinstance(policy, tf_policy.TFPolicy):
raise TypeError('policy is not a TFPolicy. Saw: %s' % type(policy))
if (batch_size is not None and
(not isinstance(batch_size, int) or batch_size < 1)):
raise ValueError(
'Expected batch_size == None or python int > 0, saw: %s' %
(batch_size,))
self._use_nest_path_signatures = use_nest_path_signatures
action_fn_input_spec = (policy.time_step_spec, policy.policy_state_spec)
if use_nest_path_signatures:
action_fn_input_spec = rename_spec_with_nest_paths(action_fn_input_spec)
else:
_check_spec(action_fn_input_spec)
# Make a shallow copy as we'll be making some changes in-place.
saved_policy = tf.Module()
saved_policy.collect_data_spec = copy.copy(policy.collect_data_spec)
saved_policy.policy_state_spec = copy.copy(policy.policy_state_spec)
if train_step is None:
if not common.has_eager_been_enabled():
raise ValueError('train_step is required in TF1 and must be a '
'`tf.Variable`: %s' % train_step)
train_step = tf.Variable(
-1,
trainable=False,
dtype=tf.int64,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
shape=())
elif not isinstance(train_step, tf.Variable):
raise ValueError('train_step must be a TensorFlow variable: %s' %
train_step)
# We will need the train step for the Checkpoint object.
self._train_step = train_step
saved_policy.train_step = self._train_step
self._metadata = metadata or {}
for key, value in self._metadata.items():
if not isinstance(key, str):
raise TypeError('Keys of metadata must be strings: %s' % key)
if not isinstance(value, tf.Variable):
raise TypeError('Values of metadata must be tf.Variable: %s' % value)
saved_policy.metadata = self._metadata
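# Expose `get_initial_state` in the SavedModel: with an unknown batch size
# the exported function takes a scalar int32 `batch_size` tensor; otherwise
# the batch size is baked in and the function takes no arguments.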
if batch_size is None:
get_initial_state_fn = policy.get_initial_state
get_initial_state_input_specs = (tf.TensorSpec(
dtype=tf.int32, shape=(), name='batch_size'),)
else:
get_initial_state_fn = functools.partial(
policy.get_initial_state, batch_size=batch_size)
get_initial_state_input_specs = ()
get_initial_state_fn = common.function()(get_initial_state_fn)
original_action_fn = policy.action
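# If a fixed seed was requested, bind it into the action call here; the
# exported action signature itself never takes a seed argument.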
if seed is not None:
def action_fn(time_step, policy_state):
time_step = cast(ts.TimeStep, time_step)
return original_action_fn(time_step, policy_state, seed=seed)
else:
action_fn = original_action_fn
def distribution_fn(time_step, policy_state):
"""Wrapper for policy.distribution() in the SavedModel."""
try:
time_step = cast(ts.TimeStep, time_step)
outs = policy.distribution(
time_step=time_step, policy_state=policy_state)
return tf.nest.map_structure(_composite_distribution, outs)
except (TypeError, NotImplementedError) as e:
# TODO(b/156526399): Move this to just the policy.distribution() call
# once tfp.experimental.as_composite() properly handles LinearOperator*
# components as well as TransformedDistributions.
logging.warning(
'WARNING: Could not serialize policy.distribution() for policy '
'"%s". Calling saved_model.distribution() will raise the following '
'assertion error: %s', policy, e)
@common.function()
def _raise():
tf.Assert(False, [str(e)])
return ()
outs = _raise()
# We call get_concrete_function() for its side effect: to ensure the proper
# ConcreteFunction is stored in the SavedModel.
get_initial_state_fn.get_concrete_function(*get_initial_state_input_specs)
train_step_fn = common.function(
lambda: saved_policy.train_step).get_concrete_function()
get_metadata_fn = common.function(
lambda: saved_policy.metadata).get_concrete_function()
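# Build batched versions of the time step and policy state specs so the
# exported signatures accept inputs with a (possibly unknown) batch
# dimension.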
batched_time_step_spec = tf.nest.map_structure(
lambda spec: add_batch_dim(spec, [batch_size]), policy.time_step_spec)
batched_time_step_spec = cast(ts.TimeStep, batched_time_step_spec)
batched_policy_state_spec = tf.nest.map_structure(
lambda spec: add_batch_dim(spec, [batch_size]),
policy.policy_state_spec)
policy_step_spec = policy.policy_step_spec
policy_state_spec = policy.policy_state_spec
if use_nest_path_signatures:
batched_time_step_spec = rename_spec_with_nest_paths(
batched_time_step_spec)
batched_policy_state_spec = rename_spec_with_nest_paths(
batched_policy_state_spec)
policy_step_spec = rename_spec_with_nest_paths(policy_step_spec)
policy_state_spec = rename_spec_with_nest_paths(policy_state_spec)
else:
_check_spec(batched_time_step_spec)
_check_spec(batched_policy_state_spec)
_check_spec(policy_step_spec)
_check_spec(policy_state_spec)
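# Build the polymorphic action/distribution functions to export. Three
# cases follow: a user-supplied `input_fn_and_spec`, a stateful policy,
# and a stateless policy.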
if input_fn_and_spec is not None:
# Store a signature based on input_fn_and_spec
@common.function()
def polymorphic_action_fn(example):
action_inputs = input_fn_and_spec[0](example)
tf.nest.map_structure(_check_compatible, action_fn_input_spec,
action_inputs)
return action_fn(*action_inputs)
@common.function()
def polymorphic_distribution_fn(example):
action_inputs = input_fn_and_spec[0](example)
tf.nest.map_structure(_check_compatible, action_fn_input_spec,
action_inputs)
return distribution_fn(*action_inputs)
batched_input_spec = tf.nest.map_structure(
lambda spec: add_batch_dim(spec, [batch_size]), input_fn_and_spec[1])
# We call get_concrete_function() for its side effect: to ensure the
# proper ConcreteFunction is stored in the SavedModel.
polymorphic_action_fn.get_concrete_function(example=batched_input_spec)
polymorphic_distribution_fn.get_concrete_function(
example=batched_input_spec)
action_input_spec = (input_fn_and_spec[1],)
else:
action_input_spec = action_fn_input_spec
if batched_policy_state_spec:
# Store the signature with a required policy state spec
polymorphic_action_fn = common.function()(action_fn)
polymorphic_action_fn.get_concrete_function(
time_step=batched_time_step_spec,
policy_state=batched_policy_state_spec)
polymorphic_distribution_fn = common.function()(distribution_fn)
polymorphic_distribution_fn.get_concrete_function(
time_step=batched_time_step_spec,
policy_state=batched_policy_state_spec)
else:
# Create a polymorphic action_fn which you can call as
# restored.action(time_step)
# or
# restored.action(time_step, ())
# (without retracing the inner action twice)
@common.function()
def polymorphic_action_fn(time_step,
policy_state=batched_policy_state_spec):
return action_fn(time_step, policy_state)
polymorphic_action_fn.get_concrete_function(
time_step=batched_time_step_spec,
policy_state=batched_policy_state_spec)
polymorphic_action_fn.get_concrete_function(
time_step=batched_time_step_spec)
@common.function()
def polymorphic_distribution_fn(time_step,
policy_state=batched_policy_state_spec):
return distribution_fn(time_step, policy_state)
polymorphic_distribution_fn.get_concrete_function(
time_step=batched_time_step_spec,
policy_state=batched_policy_state_spec)
polymorphic_distribution_fn.get_concrete_function(
time_step=batched_time_step_spec)
signatures = {
# CompositeTensors aren't well supported by old-style signature
# mechanisms, so we do not have a signature for policy.distribution.
'action':
_function_with_flat_signature(
polymorphic_action_fn,
input_specs=action_input_spec,
output_spec=policy_step_spec,
include_batch_dimension=True,
batch_size=batch_size),
'get_initial_state':
_function_with_flat_signature(
get_initial_state_fn,
input_specs=get_initial_state_input_specs,
output_spec=policy_state_spec,
include_batch_dimension=False),
'get_train_step':
_function_with_flat_signature(
train_step_fn,
input_specs=(),
output_spec=train_step.dtype,
include_batch_dimension=False),
'get_metadata':
_function_with_flat_signature(
get_metadata_fn,
input_specs=(),
output_spec=tf.nest.map_structure(lambda v: v.dtype,
self._metadata),
include_batch_dimension=False),
}
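# Attach the wrapped functions as attributes of the module being saved so
# they remain callable on the restored SavedModel, in addition to the flat
# signatures above.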
saved_policy.action = polymorphic_action_fn
saved_policy.distribution = polymorphic_distribution_fn
saved_policy.get_initial_state = get_initial_state_fn
saved_policy.get_train_step = train_step_fn
saved_policy.get_metadata = get_metadata_fn
# Adding variables as an attribute to facilitate updating them.
saved_policy.model_variables = policy.variables()
# TODO(b/156779400): Move to a public API for accessing all trackable leaf
# objects (once it's available). For now, we have no other way of tracking
# objects like Tables, Vocabulary files, etc.
try:
saved_policy._all_assets = {
name: ref
for name, ref in policy._unconditional_checkpoint_dependencies} # pylint: disable=protected-access
except AttributeError as e:
if '_self_unconditional' in str(e):
logging.warning(
'Unable to capture all trackable objects in policy "%s". This '
'may be okay. Error: %s', policy, e)
else:
raise
self._policy = saved_policy
self._raw_policy = policy
self._batch_size = batch_size
self._signatures = signatures
self._action_input_spec = action_input_spec
self._policy_step_spec = policy_step_spec
self._policy_state_spec = policy_state_spec