def train()

in community-content/tf_agents_bandits_movie_recommendation_with_kfp_and_vertex_sdk/step_by_step_sdk_tf_agents_bandits_movie_recommendation/src/training/policy_util.py
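
The imports below are reconstructed from the names this function uses, so the
listing is self-contained; the `trainer` module is assumed to be
`tf_agents.bandits.agents.examples.v2.trainer`, which provides
`get_replay_buffer` and `get_training_loop_fn`.

import collections
from typing import Callable, Dict, List, Optional, TypeVar

from tf_agents.agents.tf_agent import TFAgent
from tf_agents.bandits.agents.examples.v2 import trainer
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments.tf_environment import TFEnvironment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.metrics.tf_metric import TFStepMetric
from tf_agents.policies import policy_saver

T = TypeVar("T")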


def train(agent: TFAgent,
          environment: TFEnvironment,
          training_loops: int,
          steps_per_loop: int,
          additional_metrics: Optional[List[TFStepMetric]] = None,
          training_data_spec_transformation_fn: Optional[Callable[[T],
                                                                  T]] = None,
          run_hyperparameter_tuning: bool = False,
          root_dir: Optional[str] = None,
          artifacts_dir: Optional[str] = None) -> Dict[str, List[float]]:
  """Performs `training_loops` iterations of training on the agent's policy.

  Uses the `environment` as the problem formulation and source of immediate
  feedback, together with the agent's algorithm, to perform `training_loops`
  iterations of on-policy training on the policy. Runs in either
  hyperparameter-tuning mode, which saves no model artifacts, or regular
  training mode, which saves artifacts to `root_dir` and `artifacts_dir`.

  Args:
    agent: An instance of `TFAgent`.
    environment: An instance of `TFEnvironment`.
    training_loops: An integer indicating how many training loops should be run.
    steps_per_loop: An integer indicating how many driver steps should be
      executed and presented to the trainer during each training loop.
    additional_metrics: Optional; list of metric objects to log, in addition to
      default metrics `NumberOfEpisodes`, `AverageReturnMetric`, and
      `AverageEpisodeLengthMetric`.
    training_data_spec_transformation_fn: Optional; function that transforms
      the data items before they get to the replay buffer.
    run_hyperparameter_tuning: Optional; whether this training logic is
      executed for the purpose of hyperparameter tuning. If so, then it does
      not save model artifacts.
    root_dir: Optional; path to the directory where training artifacts are
      written; usually used for a default or auto-generated location. Do not
      specify this argument if using hyperparameter tuning instead of training.
    artifacts_dir: Optional; path to an extra directory where training
      artifacts are written; usually used for a mutually agreed location from
      which artifacts will be loaded. Do not specify this argument if using
      hyperparameter tuning instead of training.

  Returns:
    A dict mapping metric names (e.g. "AverageReturnMetric") to a list of
    intermediate metric values over `training_loops` iterations of training.
  """
  if run_hyperparameter_tuning and not (root_dir is None and
                                        artifacts_dir is None):
    raise ValueError("Do not specify `root_dir` or `artifacts_dir` when" +
                     " running hyperparameter tuning.")

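  # Derive the replay buffer's data spec from the policy's trajectory spec,
  # applying the optional transformation function if one was provided.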
  if training_data_spec_transformation_fn is None:
    data_spec = agent.policy.trajectory_spec
  else:
    data_spec = training_data_spec_transformation_fn(
        agent.policy.trajectory_spec)
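  # The replay buffer is sized for the `steps_per_loop` batched trajectories
  # collected in each training loop.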
  replay_buffer = trainer.get_replay_buffer(data_spec, environment.batch_size,
                                            steps_per_loop)

  # `step_metric` records the number of individual rounds of bandit interaction;
  # that is, (number of trajectories) * batch_size.
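  # For example, with environment.batch_size = 8 and steps_per_loop = 2, each
  # training loop advances `step_metric` by 16.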
  step_metric = tf_metrics.EnvironmentSteps()
  metrics = [
      tf_metrics.NumberOfEpisodes(),
      tf_metrics.AverageEpisodeLengthMetric(batch_size=environment.batch_size)
  ]
  if additional_metrics:
    metrics += additional_metrics

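  # If the environment emits a dictionary of rewards, track each entry with
  # `AverageReturnMultiMetric`; otherwise track the scalar average return.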
  if isinstance(environment.reward_spec(), dict):
    metrics += [tf_metrics.AverageReturnMultiMetric(
        reward_spec=environment.reward_spec(),
        batch_size=environment.batch_size)]
  else:
    metrics += [
        tf_metrics.AverageReturnMetric(batch_size=environment.batch_size)]

  # Store intermediate metric results, indexed by metric names.
  metric_results = collections.defaultdict(list)

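  # Data items must go through the same transformation as the data spec
  # before they are written to the replay buffer.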
  if training_data_spec_transformation_fn is not None:
    add_batch_fn = lambda data: replay_buffer.add_batch(  # pylint: disable=g-long-lambda
        training_data_spec_transformation_fn(data))
  else:
    add_batch_fn = replay_buffer.add_batch

  observers = [add_batch_fn, step_metric] + metrics

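  # The driver runs the collect policy until `steps_per_loop * batch_size`
  # environment steps are taken, passing each trajectory to the observers.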
  driver = dynamic_step_driver.DynamicStepDriver(
      env=environment,
      policy=agent.collect_policy,
      num_steps=steps_per_loop * environment.batch_size,
      observers=observers)

  training_loop = trainer.get_training_loop_fn(
      driver, replay_buffer, agent, steps_per_loop)
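  # A policy saver is only needed in regular training mode, where model
  # artifacts are exported after training.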
  if not run_hyperparameter_tuning:
    saver = policy_saver.PolicySaver(agent.policy)

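  # Each iteration collects a fresh batch of experience with the driver,
  # trains the agent on it, then logs and records all metric values.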
  for _ in range(training_loops):
    training_loop()
    metric_utils.log_metrics(metrics)
    for metric in metrics:
      metric.tf_summaries(train_step=step_metric.result())
      metric_results[type(metric).__name__].append(metric.result().numpy())
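  # Export the trained policy to both artifact locations (both directories
  # must be supplied in regular training mode).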
  if not run_hyperparameter_tuning:
    saver.save(root_dir)
    saver.save(artifacts_dir)
  return metric_results
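
A minimal usage sketch follows. It assumes `agent` and `environment` are an
already-constructed `TFAgent` and `TFEnvironment` (for instance, the bandit
agent and MovieLens environment built elsewhere in this sample); the directory
paths and loop counts are illustrative, not taken from the module.

# Regular training mode: collect and train for 100 loops, then save the
# trained policy to both (hypothetical) artifact locations.
metric_results = train(
    agent=agent,
    environment=environment,
    training_loops=100,
    steps_per_loop=2,
    root_dir="gs://example-bucket/root",            # hypothetical path
    artifacts_dir="gs://example-bucket/artifacts")  # hypothetical path
print(metric_results["AverageReturnMetric"])  # one value per training loop
                                              # (key for scalar reward specs)

# Hyperparameter-tuning mode: artifact directories must be omitted, and no
# model artifacts are saved.
tuning_results = train(
    agent=agent,
    environment=environment,
    training_loops=100,
    steps_per_loop=2,
    run_hyperparameter_tuning=True)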