def run()

in mesh_tensorflow/transformer/utils.py


def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=None,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        eval_checkpoint_step=None,
        export_checkpoint_step=None,
        export_path="",
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=5000,
        keep_checkpoint_max=None,
        eval_summary_dir=None,
        batch_size=("tokens_per_replica", 2048),
        train_steps=auto_train_steps,
        total_run_steps=None,
        sequence_length=None,
        mesh_shape=gin.REQUIRED,
        mesh_devices=None,
        layout_rules=gin.REQUIRED,
        learning_rate_schedule=None,
        optimizer=None,
        predict_fn=None,
        variable_filter=None,
        perplexity_eval_steps=100,
        init_checkpoint=None,
        ensemble_inputs=None,
        train_model_fn=train_model,
        skip_seen_data=False,
        seen_data_init_step=0,
        output_eval_examples=True,
        checkpoint_input_pipeline=False,
        eval_dir_suffix=None):
  """Run training, eval, or inference depending on `mode`.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located
    model_dir: string, estimator model_dir
    model_type: a string, see `get_estimator` docstring for details.
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple.
    train_dataset_fn: A function returning a tf.data.Dataset, see `train_model`
      docstring for details.
    eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
      See `eval_model` docstring for details.
    dataset_split: a string
    autostack: boolean, see `get_estimator` docstring for details.
    eval_checkpoint_step: int, list of ints, or None, see `eval_model`
      docstring for details.
    export_checkpoint_step: int or None, see `export_model` docstring for
      details.
    export_path: a string, path to export the saved model
    mode: string, one of
      train - train the model
      eval - eval the model by decoding predictions
      score_eval - eval the model by computing log likelihood scores of targets
      perplexity_eval - eval the model by computing perplexity
      infer - decode predictions based on inputs
      score_from_dataset - compute scores of targets from a dataset
      score_from_strings - compute scores of targets from strings or a file
      export_score - export a model that scores provided examples
      export_infer - export a model that decodes predictions based on inputs
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, see `get_estimator` docstring.
    keep_checkpoint_max: an integer, see `get_estimator` docstring.
    eval_summary_dir: str, see `eval_model` docstring for details.
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is the global batch size and not the
      per-shard batch size.
    train_steps: An integer or a function with the same signature as
      auto_train_steps().  Total number of training steps in this run.
    total_run_steps: An integer, used when training is split over multiple
      runs. This value is gin-configurable and is passed as the
      `total_train_steps` of the learning_rate_schedule.
    sequence_length: an integer or a dict from feature-key to integer
      the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
      May also be set to `None` in eval mode to automatically compute the
      maximum length of the examples.
    mesh_shape: an input to mtf.convert_to_shape()
    mesh_devices: a list of strings, see `get_estimator` docstring.
    layout_rules: an input to mtf.convert_to_layout_rules()
    learning_rate_schedule: a function which takes the scalar argument `step`
      and the numeric argument `total_train_steps` and returns the scalar
      learning rate.  Alternatively a float.  Alternatively, a list of
      such factors to be multiplied together.
    optimizer: a class extending optimize.Optimizer, required for training
    predict_fn: an optional function, see `get_estimator` docstring for details.
    variable_filter: a string, see `get_estimator` docstring for details.
    perplexity_eval_steps: an integer - number of steps for perplexity eval
    init_checkpoint: a string, see `get_estimator` docstring for details.
    ensemble_inputs: an integer, see `train_model` docstring for details.
    train_model_fn: an optional train function, is `train_model` by default.
    skip_seen_data: a boolean, is `False` by default. Used when a training run
      restarts to skip already-seen data. This flag only behaves consistently
      when every setting of the model (such as batch size and random seed) is
      the same between the original run and the new run. May require a
      significant amount of time to skip a large number of steps.
    seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
      steps from this starting point. Useful when finetuning.
    output_eval_examples: a boolean, is `True` by default. Used to decide
      whether to dump inputs, targets, and predictions of the eval examples
      in plaintext to eval_summary_dir.
    checkpoint_input_pipeline: a boolean, whether to checkpoint the input
      pipeline in order to restart from the previous run. May require a large
      amount of disk space for complicated input pipelines.
    eval_dir_suffix: a string, if not None then will be appended to the eval
      subdirectory name for all three eval modes:
      `perplexity_eval`, `eval`, `score_eval`.
  """
  if isinstance(sequence_length, int):
    sequence_length = {"inputs": sequence_length,
                       "targets": sequence_length}

  if not isinstance(batch_size, int):
    batch_size = compute_batch_size(
        sequence_length, mesh_shape, layout_rules, batch_size)
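  # Illustrative sketch: with batch_size=("tokens_per_replica", 2048) and
  # 512-token sequences, compute_batch_size() yields roughly 2048 // 512 = 4
  # sequences per replica, scaled by the number of data-parallel replicas
  # implied by mesh_shape and layout_rules.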

  if not isinstance(train_steps, int):
    train_steps = train_steps(batch_size, sequence_length)
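  # Note: the default, auto_train_steps, derives the step count from a
  # gin-configurable budget of training tokens, roughly
  # train_tokens // (batch_size * sequence_length["targets"]).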

  if total_run_steps is None:
    total_run_steps = train_steps
  if isinstance(learning_rate_schedule, list):
    learning_rate_schedule = functools.partial(
        learning_rate_schedules.product_learning_rate,
        total_train_steps=total_run_steps, factors=learning_rate_schedule)

  if callable(learning_rate_schedule):
    learning_rate_schedule = functools.partial(
        learning_rate_schedule, total_train_steps=total_run_steps)
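  # Illustrative sketch: learning_rate_schedule=[0.5, my_decay_fn] (with
  # my_decay_fn a hypothetical factor taking `step` and `total_train_steps`)
  # resolves to product_learning_rate, which multiplies the factors together
  # at each step.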

  tf.logging.info("model_type=%s", model_type,)
  tf.logging.info("mode=%s", mode,)
  tf.logging.info("sequence_length=%s", sequence_length,)
  tf.logging.info("batch_size=%s", batch_size,)
  tf.logging.info("train_steps=%s", train_steps,)
  if total_run_steps is not None:
    tf.logging.info("total_run_steps=%s", total_run_steps,)
  tf.logging.info("mesh_shape=%s", mesh_shape,)
  tf.logging.info("layout_rules=%s", layout_rules,)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  if mode != "train":
    ensemble_inputs = None

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu, zone=tpu_zone, project=gcp_project) if tpu else None

  tf.logging.info("Building TPUConfig with tpu_job_name=%s", tpu_job_name)

  score_in_predict_mode = "score" in mode
  estimator_fn = functools.partial(
      get_estimator,
      model_type=model_type,
      vocabulary=vocabulary,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape,
      model_dir=model_dir,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      learning_rate_schedule=learning_rate_schedule,
      keep_checkpoint_max=keep_checkpoint_max,
      save_checkpoints_steps=save_checkpoints_steps,
      optimizer=optimizer,
      predict_fn=predict_fn,
      score_in_predict_mode=score_in_predict_mode,
      variable_filter=variable_filter,
      init_checkpoint=init_checkpoint,
      ensemble_inputs=ensemble_inputs,
      use_tpu=tpu,
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      cluster=cluster,
      mesh_devices=mesh_devices)

  if mode not in ("eval", "score_eval"):
    if sequence_length is None:
      raise ValueError(f"`sequence_length` must be specified in '{mode}' mode.")
    estimator = estimator_fn()

  if mode == "train":
    # train_dataset_fn could be None if train_model_fn is not equal to
    # train_model
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin")

    train_model_fn(estimator, vocabulary, sequence_length, batch_size,
                   train_dataset_fn, train_steps, ensemble_inputs,
                   skip_seen_data=skip_seen_data,
                   seen_data_init_step=seen_data_init_step,
                   checkpoint_input_pipeline=checkpoint_input_pipeline)

  elif mode == "perplexity_eval":
    if eval_dataset_fn is None:
      if train_dataset_fn is not None:
        tf.logging.warning("Using train_dataset_fn for perplexity eval")
        eval_datasets = [transformer_dataset.EvalDataset(
            name="eval",
            dataset_fn=functools.partial(train_dataset_fn,
                                         sequence_length=sequence_length,
                                         vocabulary=vocabulary,
                                         dataset_split=dataset_split),
            postprocess_fn=None,
            metric_fns=None)]
      else:
        raise ValueError(
            "for perplexity_eval, "
            "must provide one of eval_dataset_fn and train_dataset_fn")
    else:
      eval_datasets = eval_dataset_fn(
          sequence_length=sequence_length,
          vocabulary=vocabulary,
          dataset_split=dataset_split,
      )
    def _input_fn(params, eval_dataset):
      del params
      ds = eval_dataset.dataset_fn().map(
          _filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
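      # Pad with zeroed-out examples so that batching with drop_remainder
      # still yields enough full batches for perplexity_eval_steps; the
      # zeroed examples presumably count as padding in the metrics.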
      ds = transformer_dataset.pad_dataset_with_zeroed_out_examples(ds)
      ds = (ds.batch(batch_size * (ensemble_inputs or 1), drop_remainder=True)
            .prefetch(tf.data.experimental.AUTOTUNE))
      return ds
    checkpoint_paths = get_checkpoint_iterator(eval_checkpoint_step, model_dir)
    for checkpoint_path in checkpoint_paths:
      for eval_dataset in eval_datasets:
        tf.random.set_random_seed(12345)
        random.seed(12345)
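        # Seeds are fixed so that every checkpoint is evaluated on the same
        # examples.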
        num_examples = batch_size * perplexity_eval_steps
        # include the number of examples in the evaluation name so as to
        # make sure we are comparing apples to apples.
        name = "%s_%s_%d" % (eval_dataset.name, dataset_split, num_examples)
        if eval_dir_suffix is not None:
          name += "_%s" % eval_dir_suffix
        _ = estimator.evaluate(
            input_fn=functools.partial(_input_fn, eval_dataset=eval_dataset),
            steps=perplexity_eval_steps,
            checkpoint_path=checkpoint_path,
            name=name)
  elif mode in ("eval", "score_eval"):
    eval_model(
        estimator_fn,
        vocabulary,
        sequence_length,
        batch_size,
        dataset_split,
        model_dir,
        eval_dataset_fn,
        eval_summary_dir,
        eval_checkpoint_step,
        eval_with_score=(mode == "score_eval"),
        output_eval_examples=output_eval_examples,
        eval_dir_suffix=eval_dir_suffix)
  elif mode == "infer":
    infer_model(estimator, vocabulary, sequence_length, batch_size, model_type,
                model_dir, eval_checkpoint_step)
  elif mode == "score_from_strings":
    score_from_strings(estimator=estimator,
                       vocabulary=vocabulary,
                       model_type=model_type,
                       batch_size=batch_size,
                       sequence_length=sequence_length,
                       model_dir=model_dir,
                       eval_checkpoint_step=eval_checkpoint_step)
  elif mode == "score_from_dataset":
    score_from_dataset(estimator, vocabulary, batch_size, sequence_length,
                       model_dir, eval_checkpoint_step, dataset_split)
  elif mode in ["export_score", "export_infer", "export"]:
    if mode == "export":
      tf.logging.warning("Mode 'export' is deprecated. "
                         "Defaulting to 'export_infer'.")
    if export_checkpoint_step:
      checkpoint_path = get_checkpoint_iterator(
          export_checkpoint_step, model_dir)
      if isinstance(checkpoint_path, list):
        checkpoint_path = checkpoint_path[0]
      else:
        checkpoint_path = next(checkpoint_path)
    else:
      # Use the latest checkpoint in the model directory.
      checkpoint_path = None
    export_model(estimator, export_path, vocabulary, sequence_length,
                 model_type, score_in_predict_mode, batch_size, checkpoint_path)

  else:
    raise ValueError(
        "unknown mode %s - must be one of train/perplexity_eval/eval/"
        "score_eval/infer/score_from_strings/score_from_dataset/"
        "export_score/export_infer" % mode)
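
Example usage (a minimal sketch, not from the source: in practice most of
these arguments are bound through gin, and `my_vocabulary` and
`my_train_dataset_fn` below are hypothetical stand-ins):

import mesh_tensorflow as mtf
from mesh_tensorflow.transformer import utils

utils.run(
    tpu_job_name=None,
    tpu=None,                              # no TPU: run on local CPU/GPU
    gcp_project=None,
    tpu_zone=None,
    model_dir="/tmp/my_model",
    vocabulary=my_vocabulary,              # hypothetical Vocabulary
    train_dataset_fn=my_train_dataset_fn,  # hypothetical, see train_model
    mode="train",
    batch_size=("tokens_per_replica", 2048),
    train_steps=10000,
    sequence_length={"inputs": 512, "targets": 512},
    mesh_shape="batch:1",
    layout_rules="batch:batch",
    learning_rate_schedule=0.003,
    optimizer=mtf.optimize.AdafactorOptimizer)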