def __init__()

in tf_agents/policies/policy_saver.py [0:0]


  def __init__(
      self,
      policy: tf_policy.TFPolicy,
      batch_size: Optional[int] = None,
      use_nest_path_signatures: bool = True,
      seed: Optional[types.Seed] = None,
      train_step: Optional[tf.Variable] = None,
      input_fn_and_spec: Optional[InputFnAndSpecType] = None,
      metadata: Optional[Dict[Text, tf.Variable]] = None
      ):
    """Initialize PolicySaver for TF policy `policy`.

    Builds a `tf.Module` wrapping `policy` and attaches `action`,
    `distribution`, `get_initial_state`, `get_train_step` and `get_metadata`
    functions to it. Their concrete functions are traced here (via
    `get_concrete_function()`) so they are stored in the SavedModel. Flat
    signatures are prepared for `action`, `get_initial_state`,
    `get_train_step` and `get_metadata` (there is no flat signature for
    `distribution`).

    Args:
      policy: A TF Policy.
      batch_size: The number of batch entries the policy will process at a time.
        This must be either `None` (unknown batch size) or a python integer.
      use_nest_path_signatures: SavedModel spec signatures will be created based
        on the structure of the specs. Otherwise all specs must have unique
        names.
      seed: Random seed for the `policy.action` call, if any (this should
        usually be `None`, except for testing).
      train_step: Variable holding the train step for the policy. The value
        saved will be set at the time `saver.save` is called. If not provided,
        train_step defaults to -1. Note since the train step must be a variable
        it is not safe to create it directly in TF1 so in that case this is a
        required parameter.
      input_fn_and_spec: A `(input_fn, tensor_spec)` tuple where input_fn is a
        function that takes inputs according to tensor_spec and converts them to
        the `(time_step, policy_state)` tuple that is used as the input to the
        action_fn. When `input_fn_and_spec` is set, `tensor_spec` is the input
        for the action signature. When `input_fn_and_spec is None`, the action
        signature takes as input `(time_step, policy_state)`.
      metadata: A dictionary of `tf.Variables` to be saved along with the
        policy.

    Raises:
      TypeError: If `policy` is not an instance of TFPolicy.
      TypeError: If a `metadata` key is not a `str` or a `metadata` value is
        not a `tf.Variable`.
      ValueError: If `train_step` is `None` and eager execution has never been
        enabled (TF1), or if `train_step` is given but is not a `tf.Variable`.
      ValueError: If use_nest_path_signatures is not used and any of the
        following `policy` specs are missing names, or the names collide:
        `policy.time_step_spec`, `policy.action_spec`,
        `policy.policy_state_spec`, `policy.info_spec`.
      ValueError: If `batch_size` is not either `None` or a python integer > 0.
    """
    if not isinstance(policy, tf_policy.TFPolicy):
      raise TypeError('policy is not a TFPolicy.  Saw: %s' % type(policy))
    if (batch_size is not None and
        (not isinstance(batch_size, int) or batch_size < 1)):
      raise ValueError(
          'Expected batch_size == None or python int > 0, saw: %s' %
          (batch_size,))

    self._use_nest_path_signatures = use_nest_path_signatures

    action_fn_input_spec = (policy.time_step_spec, policy.policy_state_spec)
    if use_nest_path_signatures:
      action_fn_input_spec = rename_spec_with_nest_paths(action_fn_input_spec)
    else:
      _check_spec(action_fn_input_spec)

    # Make a shallow copy as we'll be making some changes in-place.
    saved_policy = tf.Module()
    saved_policy.collect_data_spec = copy.copy(policy.collect_data_spec)
    saved_policy.policy_state_spec = copy.copy(policy.policy_state_spec)

    if train_step is None:
      if not common.has_eager_been_enabled():
        raise ValueError('train_step is required in TF1 and must be a '
                         '`tf.Variable`: %s' % train_step)
      train_step = tf.Variable(
          -1,
          trainable=False,
          dtype=tf.int64,
          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
          shape=())
    elif not isinstance(train_step, tf.Variable):
      # Wrap in a 1-tuple so a tuple-valued argument cannot be mistaken for
      # the format-argument tuple (which would raise TypeError, not the
      # documented ValueError).
      raise ValueError('train_step must be a TensorFlow variable: %s' %
                       (train_step,))

    # We will need the train step for the Checkpoint object.
    self._train_step = train_step
    saved_policy.train_step = self._train_step

    self._metadata = metadata or {}
    for key, value in self._metadata.items():
      # 1-tuples keep the intended error message even for tuple keys/values.
      if not isinstance(key, str):
        raise TypeError('Keys of metadata must be strings: %s' % (key,))
      if not isinstance(value, tf.Variable):
        raise TypeError('Values of metadata must be tf.Variable: %s' %
                        (value,))
    saved_policy.metadata = self._metadata

    # With an unknown batch size, get_initial_state takes the batch size as a
    # scalar int32 input; otherwise it is bound here and takes no inputs.
    if batch_size is None:
      get_initial_state_fn = policy.get_initial_state
      get_initial_state_input_specs = (tf.TensorSpec(
          dtype=tf.int32, shape=(), name='batch_size'),)
    else:
      get_initial_state_fn = functools.partial(
          policy.get_initial_state, batch_size=batch_size)
      get_initial_state_input_specs = ()

    get_initial_state_fn = common.function()(get_initial_state_fn)

    original_action_fn = policy.action

    if seed is not None:

      def action_fn(time_step, policy_state):
        time_step = cast(ts.TimeStep, time_step)
        return original_action_fn(time_step, policy_state, seed=seed)
    else:
      action_fn = original_action_fn

    def distribution_fn(time_step, policy_state):
      """Wrapper for policy.distribution() in the SavedModel."""
      try:
        time_step = cast(ts.TimeStep, time_step)
        outs = policy.distribution(
            time_step=time_step, policy_state=policy_state)
        return tf.nest.map_structure(_composite_distribution, outs)
      except (TypeError, NotImplementedError) as e:
        # TODO(b/156526399): Move this to just the policy.distribution() call
        # once tfp.experimental.as_composite() properly handles LinearOperator*
        # components as well as TransformedDistributions.
        logging.warning(
            'WARNING: Could not serialize policy.distribution() for policy '
            '"%s". Calling saved_model.distribution() will raise the following '
            'assertion error: %s', policy, e)
        @common.function()
        def _raise():
          tf.Assert(False, [str(e)])
          return ()
        # Traced for its tf.Assert side effect only; distribution_fn falls
        # through and returns None on this path.
        _raise()

    # We call get_concrete_function() for its side effect: to ensure the proper
    # ConcreteFunction is stored in the SavedModel.
    get_initial_state_fn.get_concrete_function(*get_initial_state_input_specs)

    train_step_fn = common.function(
        lambda: saved_policy.train_step).get_concrete_function()
    get_metadata_fn = common.function(
        lambda: saved_policy.metadata).get_concrete_function()

    batched_time_step_spec = tf.nest.map_structure(
        lambda spec: add_batch_dim(spec, [batch_size]), policy.time_step_spec)
    batched_time_step_spec = cast(ts.TimeStep, batched_time_step_spec)
    batched_policy_state_spec = tf.nest.map_structure(
        lambda spec: add_batch_dim(spec, [batch_size]),
        policy.policy_state_spec)

    policy_step_spec = policy.policy_step_spec
    policy_state_spec = policy.policy_state_spec

    if use_nest_path_signatures:
      batched_time_step_spec = rename_spec_with_nest_paths(
          batched_time_step_spec)
      batched_policy_state_spec = rename_spec_with_nest_paths(
          batched_policy_state_spec)
      policy_step_spec = rename_spec_with_nest_paths(policy_step_spec)
      policy_state_spec = rename_spec_with_nest_paths(policy_state_spec)
    else:
      _check_spec(batched_time_step_spec)
      _check_spec(batched_policy_state_spec)
      _check_spec(policy_step_spec)
      _check_spec(policy_state_spec)

    if input_fn_and_spec is not None:
      # Store a signature based on input_fn_and_spec
      @common.function()
      def polymorphic_action_fn(example):
        action_inputs = input_fn_and_spec[0](example)
        tf.nest.map_structure(_check_compatible, action_fn_input_spec,
                              action_inputs)
        return action_fn(*action_inputs)

      @common.function()
      def polymorphic_distribution_fn(example):
        action_inputs = input_fn_and_spec[0](example)
        tf.nest.map_structure(_check_compatible, action_fn_input_spec,
                              action_inputs)
        return distribution_fn(*action_inputs)

      batched_input_spec = tf.nest.map_structure(
          lambda spec: add_batch_dim(spec, [batch_size]), input_fn_and_spec[1])
      # We call get_concrete_function() for its side effect: to ensure the
      # proper ConcreteFunction is stored in the SavedModel.
      polymorphic_action_fn.get_concrete_function(example=batched_input_spec)
      polymorphic_distribution_fn.get_concrete_function(
          example=batched_input_spec)

      action_input_spec = (input_fn_and_spec[1],)

    else:
      action_input_spec = action_fn_input_spec
      if batched_policy_state_spec:
        # Store the signature with a required policy state spec
        polymorphic_action_fn = common.function()(action_fn)
        polymorphic_action_fn.get_concrete_function(
            time_step=batched_time_step_spec,
            policy_state=batched_policy_state_spec)

        polymorphic_distribution_fn = common.function()(distribution_fn)
        polymorphic_distribution_fn.get_concrete_function(
            time_step=batched_time_step_spec,
            policy_state=batched_policy_state_spec)
      else:
        # Create a polymorphic action_fn which you can call as
        #  restored.action(time_step)
        # or
        #  restored.action(time_step, ())
        # (without retracing the inner action twice)
        @common.function()
        def polymorphic_action_fn(time_step,
                                  policy_state=batched_policy_state_spec):
          return action_fn(time_step, policy_state)

        polymorphic_action_fn.get_concrete_function(
            time_step=batched_time_step_spec,
            policy_state=batched_policy_state_spec)
        polymorphic_action_fn.get_concrete_function(
            time_step=batched_time_step_spec)

        @common.function()
        def polymorphic_distribution_fn(time_step,
                                        policy_state=batched_policy_state_spec):
          return distribution_fn(time_step, policy_state)

        polymorphic_distribution_fn.get_concrete_function(
            time_step=batched_time_step_spec,
            policy_state=batched_policy_state_spec)
        polymorphic_distribution_fn.get_concrete_function(
            time_step=batched_time_step_spec)

    signatures = {
        # CompositeTensors aren't well supported by old-style signature
        # mechanisms, so we do not have a signature for policy.distribution.
        'action':
            _function_with_flat_signature(
                polymorphic_action_fn,
                input_specs=action_input_spec,
                output_spec=policy_step_spec,
                include_batch_dimension=True,
                batch_size=batch_size),
        'get_initial_state':
            _function_with_flat_signature(
                get_initial_state_fn,
                input_specs=get_initial_state_input_specs,
                output_spec=policy_state_spec,
                include_batch_dimension=False),
        'get_train_step':
            _function_with_flat_signature(
                train_step_fn,
                input_specs=(),
                output_spec=train_step.dtype,
                include_batch_dimension=False),
        'get_metadata':
            _function_with_flat_signature(
                get_metadata_fn,
                input_specs=(),
                output_spec=tf.nest.map_structure(lambda v: v.dtype,
                                                  self._metadata),
                include_batch_dimension=False),
    }

    saved_policy.action = polymorphic_action_fn
    saved_policy.distribution = polymorphic_distribution_fn
    saved_policy.get_initial_state = get_initial_state_fn
    saved_policy.get_train_step = train_step_fn
    saved_policy.get_metadata = get_metadata_fn
    # Adding variables as an attribute to facilitate updating them.
    saved_policy.model_variables = policy.variables()

    # TODO(b/156779400): Move to a public API for accessing all trackable leaf
    # objects (once it's available).  For now, we have no other way of tracking
    # objects like Tables, Vocabulary files, etc.
    try:
      saved_policy._all_assets = {
          name: ref
          for name, ref in policy._unconditional_checkpoint_dependencies}  # pylint: disable=protected-access
    except AttributeError as e:
      if '_self_unconditional' in str(e):
        logging.warning(
            'Unable to capture all trackable objects in policy "%s".  This '
            'may be okay.  Error: %s', policy, e)
      else:
        raise e

    self._policy = saved_policy
    self._raw_policy = policy
    self._batch_size = batch_size
    self._signatures = signatures
    self._action_input_spec = action_input_spec
    self._policy_step_spec = policy_step_spec
    self._policy_state_spec = policy_state_spec