def Run()

in perfkitbenchmarker/linux_benchmarks/mnist_benchmark.py [0:0]


def Run(benchmark_spec):
  """Run MNIST on the cluster.

  Builds the benchmark shell command for either the TPU or the plain
  TensorFlow model checkout, then runs the training and/or evaluation
  phases depending on which step counts are positive.

  Args:
    benchmark_spec: The benchmark specification. Contains all data that is
      required to run the benchmark.

  Returns:
    A list of sample.Sample objects.
  """
  _UpdateBenchmarkSpecWithFlags(benchmark_spec)
  vm = benchmark_spec.vms[0]

  if benchmark_spec.tpus:
    # The TPU model lives in the tpu/ checkout; mnist_tpu.py additionally
    # takes an --iterations flag.
    mnist_benchmark_script = 'mnist_tpu.py'
    mnist_benchmark_cmd = (
        'cd tpu/models && '
        'export PYTHONPATH=$(pwd) && '
        'cd official/mnist && '
        'python {script} '
        '--data_dir={data_dir} '
        '--iterations={iterations} '
        '--model_dir={model_dir} '
        '--batch_size={batch_size}'.format(
            script=mnist_benchmark_script,
            data_dir=benchmark_spec.data_dir,
            iterations=benchmark_spec.iterations,
            model_dir=benchmark_spec.model_dir,
            batch_size=benchmark_spec.batch_size,
        )
    )
  else:
    mnist_benchmark_script = 'mnist.py'
    mnist_benchmark_cmd = (
        'cd models && '
        'export PYTHONPATH=$(pwd) && '
        'cd official/mnist && '
        'python {script} '
        '--data_dir={data_dir} '
        '--model_dir={model_dir} '
        '--batch_size={batch_size} '.format(
            script=mnist_benchmark_script,
            data_dir=benchmark_spec.data_dir,
            model_dir=benchmark_spec.model_dir,
            batch_size=benchmark_spec.batch_size,
        )
    )

  if nvidia_driver.CheckNvidiaGpuExists(vm):
    # Prefix CUDA/cuDNN environment variables so TensorFlow finds the GPU.
    mnist_benchmark_cmd = '{env} {cmd}'.format(
        env=tensorflow.GetEnvironmentVars(vm), cmd=mnist_benchmark_cmd
    )
  samples = []
  metadata = CreateMetadataDict(benchmark_spec)

  # Default to 0 so the eval section below does not raise NameError when
  # training is skipped (train_steps == 0). NOTE(review): eval samples
  # intentionally reuse the *training* wall time, matching prior behavior.
  elapsed_seconds = 0

  if benchmark_spec.train_steps > 0:
    if benchmark_spec.tpus:
      # Point the trainer at the training TPU group and shard across its
      # cores; prediction is disabled during the training phase.
      tpu = benchmark_spec.tpu_groups['train'].GetName()
      num_shards = '--num_shards={}'.format(
          benchmark_spec.tpu_groups['train'].GetNumShards()
      )
      mnist_benchmark_train_cmd = (
          '{cmd} --tpu={tpu} --use_tpu={use_tpu} --train_steps={train_steps} '
          '{num_shards} --noenable_predict'.format(
              cmd=mnist_benchmark_cmd,
              tpu=tpu,
              use_tpu=bool(benchmark_spec.tpus),
              train_steps=benchmark_spec.train_steps,
              num_shards=num_shards,
          )
      )
    else:
      # The non-TPU model is driven by epochs rather than steps.
      mnist_benchmark_train_cmd = '{cmd} --train_epochs={train_epochs} '.format(
          cmd=mnist_benchmark_cmd, train_epochs=benchmark_spec.train_epochs
      )

    start = time.time()
    stdout, stderr = vm.RobustRemoteCommand(mnist_benchmark_train_cmd)
    elapsed_seconds = time.time() - start
    samples.extend(
        MakeSamplesFromTrainOutput(
            metadata,
            stdout + stderr,
            elapsed_seconds,
            benchmark_spec.train_steps,
        )
    )

  if benchmark_spec.eval_steps > 0:
    if benchmark_spec.tpus:
      # Evaluation runs against its own TPU group.
      mnist_benchmark_eval_cmd = (
          '{cmd} --tpu={tpu} --use_tpu={use_tpu} --eval_steps={eval_steps}'
          .format(
              cmd=mnist_benchmark_cmd,
              use_tpu=bool(benchmark_spec.tpus),
              tpu=benchmark_spec.tpu_groups['eval'].GetName(),
              eval_steps=benchmark_spec.eval_steps,
          )
      )
    else:
      mnist_benchmark_eval_cmd = '{cmd} --eval_steps={eval_steps}'.format(
          cmd=mnist_benchmark_cmd, eval_steps=benchmark_spec.eval_steps
      )

    stdout, stderr = vm.RobustRemoteCommand(mnist_benchmark_eval_cmd)
    samples.extend(
        MakeSamplesFromEvalOutput(metadata, stdout + stderr, elapsed_seconds)
    )
  return samples