# Copyright 2019 PerfKitBenchmarker Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runs benchmarks in PerfKitBenchmarker.

All benchmarks in PerfKitBenchmarker export the following interface:

GetConfig: returns the benchmark's configuration: the name of the benchmark,
         the number of machines required to run one instance of the
         benchmark, a detailed description of the benchmark, and whether the
         benchmark requires a scratch disk.
Prepare: this function takes the benchmark spec, which holds the VMs, as an
         input parameter. The benchmark will then get all binaries required
         to run the benchmark and, if required, create data files.
Run: this function takes the benchmark spec as an input parameter. The
     benchmark will then run upon the machines specified. The function will
     return a list of samples containing the results of the benchmark.
Cleanup: this function takes the benchmark spec as an input parameter. The
         benchmark will then return the machines to the state they were in
         before Prepare was called.

PerfKitBenchmarker has the following run stages: provision, prepare,
    run, cleanup, teardown, and all.

provision: Read command-line flags, decide what benchmarks to run, and
    create the necessary resources for each benchmark, including
    networks, VMs, disks, and keys, and generate a run_uri, which can
    be used to resume execution at later stages.
prepare: Execute the Prepare function of each benchmark to install
         necessary software, upload datafiles, etc.
run: Execute the Run function of each benchmark and collect the
     generated samples. The publisher may publish these samples
     according to PKB's settings. The Run stage can be called multiple
     times with the run_uri generated by the provision stage.
cleanup: Execute the Cleanup function of each benchmark to uninstall
         software and delete data files.
teardown: Delete VMs, key files, networks, and disks created in the
    'provision' stage.

all: PerfKitBenchmarker will run all of the above stages (provision,
     prepare, run, cleanup, teardown). Any resources generated in the
     provision stage will be automatically deleted in the teardown
     stage, even if there is an error in an earlier stage. When PKB is
     running in this mode, the run cannot be repeated or resumed using
     the run_uri.
"""


import collections
from collections.abc import Mapping, MutableSequence
import copy
import itertools
import json
import logging
import multiprocessing
from os.path import isfile
import pickle
import random
import re
import sys
import threading
import time
import types
from typing import Any, Collection, Dict, List, Sequence, Set, Tuple, Type
import uuid

from absl import flags
from perfkitbenchmarker import archive
from perfkitbenchmarker import background_tasks
from perfkitbenchmarker import benchmark_lookup
from perfkitbenchmarker import benchmark_sets
from perfkitbenchmarker import benchmark_spec as bm_spec
from perfkitbenchmarker import benchmark_status
from perfkitbenchmarker import configs
from perfkitbenchmarker import context
from perfkitbenchmarker import errors
from perfkitbenchmarker import events
from perfkitbenchmarker import flag_alias
from perfkitbenchmarker import flag_util
from perfkitbenchmarker import flags as pkb_flags
from perfkitbenchmarker import linux_benchmarks
from perfkitbenchmarker import linux_virtual_machine
from perfkitbenchmarker import log_util
from perfkitbenchmarker import os_types
from perfkitbenchmarker import package_lookup
from perfkitbenchmarker import providers
from perfkitbenchmarker import publisher
from perfkitbenchmarker import requirements
from perfkitbenchmarker import sample
from perfkitbenchmarker import stages
from perfkitbenchmarker import static_virtual_machine
from perfkitbenchmarker import time_triggers
from perfkitbenchmarker import timing_util
from perfkitbenchmarker import traces
from perfkitbenchmarker import version
from perfkitbenchmarker import virtual_machine
from perfkitbenchmarker import vm_util
from perfkitbenchmarker import windows_benchmarks
from perfkitbenchmarker.configs import benchmark_config_spec
from perfkitbenchmarker.linux_benchmarks import cluster_boot_benchmark
from perfkitbenchmarker.linux_benchmarks import cuda_memcpy_benchmark
from perfkitbenchmarker.linux_packages import build_tools

# Add additional flags to ./flags.py
# Keeping this flag here rather than flags.py to avoid a circular dependency
# on benchmark_status.
# Both the default value and the set of allowed values are
# benchmark_status.FailedSubstatus.RETRYABLE_SUBSTATUSES.
_RETRY_SUBSTATUSES = flags.DEFINE_multi_enum(
    'retry_substatuses',
    benchmark_status.FailedSubstatus.RETRYABLE_SUBSTATUSES,
    benchmark_status.FailedSubstatus.RETRYABLE_SUBSTATUSES,
    'The failure substatuses to retry on. By default, failed runs are run with '
    'the same previous config.',
)

# File name for benchmark completion statuses. NOTE(review): presumably the
# file that _WriteCompletionStatusFile output goes to -- confirm at call sites.
COMPLETION_STATUS_FILE_NAME = 'completion_statuses.json'
# NOTE(review): REQUIRED_INFO and REQUIRED_EXECUTABLES are not referenced in
# this part of the file; judging by their names they list required benchmark
# info keys and required local executables -- confirm at their use sites.
REQUIRED_INFO = ['scratch_disk', 'num_machines']
REQUIRED_EXECUTABLES = frozenset(['ssh', 'ssh-keygen', 'scp', 'openssl'])
# Maximum length of a user-supplied --run_uri (enforced in _InitializeRunUri).
MAX_RUN_URI_LENGTH = 12
FLAGS = flags.FLAGS

# Define patterns for help text processing (used by _PrintHelpMD).
BASE_RELATIVE = (  # Relative path from markdown output to PKB home for link writing.
    '../'
)
MODULE_REGEX = r'^\s+?(.*?):.*'  # Pattern that matches module names.
FLAGS_REGEX = r'(^\s\s--.*?(?=^\s\s--|\Z))+?'  # Pattern that matches each flag.
FLAGNAME_REGEX = (  # Pattern that matches flag name in each flag.
    r'^\s+?(--.*?)(:.*\Z)'
)
DOCSTRING_REGEX = (  # Pattern that matches triple quoted comments.
    r'"""(.*?|$)"""'
)

# NOTE(review): _TEARDOWN_EVENT and _ANY_ZONE are not referenced in this part
# of the file; presumably a cross-process teardown signal and a wildcard zone
# value respectively -- confirm at their use sites.
_TEARDOWN_EVENT = multiprocessing.Event()
_ANY_ZONE = 'any'

# Run trace and time-trigger registration when events.register_tracers is
# sent (DoProvisionPhase sends it with the parsed flags).
events.register_tracers.connect(traces.RegisterAll)
events.register_tracers.connect(time_triggers.RegisterAll)


@flags.multi_flags_validator(
    ['smart_quota_retry', 'smart_capacity_retry', 'retries', 'zones', 'zone'],
    message=(
        'Smart zone retries requires exactly one single zone from --zones '
        'or --zone, as well as retry count > 0.'
    ),
)
def ValidateSmartZoneRetryFlags(flags_dict):
  """Validates the flags used by smart quota/capacity zone retries.

  When either smart retry flag is enabled, --retries must be nonzero and
  exactly one zone must be given, through either --zones or --zone but not
  both.

  Args:
    flags_dict: mapping from each validated flag name to its value.

  Returns:
    True if smart retries are disabled or the flag combination is valid.
  """
  smart_retry_enabled = (
      flags_dict['smart_quota_retry'] or flags_dict['smart_capacity_retry']
  )
  if not smart_retry_enabled:
    return True
  if flags_dict['retries'] == 0:
    return False
  zones = flags_dict['zones']
  zone = flags_dict['zone']
  # Short-circuit keeps len(zone) from being evaluated when the first side
  # already holds.
  return (len(zones) == 1 and not zone) or (len(zone) == 1 and not zones)


@flags.multi_flags_validator(
    ['retries', 'run_stage'],
    message='Retries requires running all stages of the benchmark.',
)
def ValidateRetriesAndRunStages(flags_dict):
  """Allows --retries > 0 only when every stage in stages.STAGES is run.

  Args:
    flags_dict: mapping with the 'retries' and 'run_stage' flag values.

  Returns:
    True if the combination is valid.
  """
  wants_retries = flags_dict['retries'] > 0
  return not wants_retries or flags_dict['run_stage'] == stages.STAGES


def ParseSkipTeardownConditions(
    skip_teardown_conditions: Collection[str],
) -> Mapping[str, Mapping[str, float | None]]:
  """Parses the skip_teardown_conditions flag.

  Used by the validator below and flag_util.ShouldTeardown to separate
  conditions passed by the --skip_teardown_conditions flag into three tokens:
      metric, lower bound, upper_bound

  Initial regex parsing captures a metric (any string before a > or <),
  direction (the > or <), and a threshold (any number after the direction).

  Args:
    skip_teardown_conditions: list of conditions to parse

  Returns:
    list of tuples of (metric, lower_bound, upper_bound)
  Raises:
    ValueError: if any condition is invalid
  """
  parsed_conditions = {}
  pattern = re.compile(
      r"""
      ([\w -]+)   # Matches all characters that could appear in a metric name
      ([<>])      # Matches < or >
      ([\d+\.]+)  # Matches any floating point number
      """,
      re.VERBOSE,
  )
  for condition in skip_teardown_conditions:
    match = pattern.match(condition)
    if not match or len(match.groups()) != 3:
      raise ValueError(
          'Invalid skip_teardown_conditions flag. Conditions must be in the '
          'format of:\n'
          '<metric><direction><threshold>;...;...\n'
          'where metric is any string, direction is either > or <, and '
          'threshold is any number.'
      )
    metric, direction, threshold = match.groups()
    # Raises ValueError if threshold is not a valid number.
    threshold = float(threshold)
    lower_bound = threshold if direction == '>' else None
    upper_bound = threshold if direction == '<' else None
    if metric not in parsed_conditions:
      parsed_conditions[metric] = {
          'lower_bound': lower_bound,
          'upper_bound': upper_bound,
      }
      continue
    # Update the existing metric's bound(s) if necessary.
    current_lower_bound = parsed_conditions[metric]['lower_bound']
    if lower_bound is not None and (
        current_lower_bound is None or lower_bound < current_lower_bound
    ):
      parsed_conditions[metric]['lower_bound'] = lower_bound
    current_upper_bound = parsed_conditions[metric]['upper_bound']
    if upper_bound is not None and (
        current_upper_bound is None or upper_bound > current_upper_bound
    ):
      parsed_conditions[metric]['upper_bound'] = upper_bound
  return parsed_conditions


@flags.validator(
    'skip_teardown_conditions',
    message='Invalid skip_teardown_conditions flag.',
)
def ValidateSkipTeardownConditions(flags_dict: Any) -> bool:
  """Validates the --skip_teardown_conditions flag.

  absl's single-flag `flags.validator` decorator calls its checker with the
  flag's value, not with a dict of flag values. A membership test for the
  key 'skip_teardown_conditions' on that value would therefore never see the
  conditions, so the value itself is parsed here. A mapping that holds the
  value under the 'skip_teardown_conditions' key is also accepted, for
  direct callers that pass one. A single string is treated as one entry.

  Args:
    flags_dict: the flag value (a collection of condition strings), or a
      mapping containing it under 'skip_teardown_conditions'.

  Returns:
    True if no conditions are set or all of them parse, False otherwise.
  """
  conditions = flags_dict
  if isinstance(conditions, Mapping):
    conditions = conditions.get('skip_teardown_conditions')
  if not conditions:
    return True
  if isinstance(conditions, str):
    # Avoid iterating a bare string character by character.
    conditions = [conditions]
  try:
    ParseSkipTeardownConditions(conditions)
  except ValueError:
    return False
  return True


def MetricMeetsConditions(
    metric_sample: Mapping[str, Any],
    conditions: Mapping[str, Mapping[str, float | None]],
) -> bool:
  """Checks if a metric sample meets any conditions.

  If a metric falls within the bounds of a condition, log the metric and the
  condition.

  Args:
    metric_sample: The metric sample to check
    conditions: The conditions to check against

  Returns:
    True if the metric sample meets any of the conditions, False otherwise.
  """
  if metric_sample['metric'] not in conditions:
    return False

  target_condition = conditions[metric_sample['metric']]
  lower_bound = target_condition['lower_bound']
  upper_bound = target_condition['upper_bound']
  lower_bound_satisfied = (
      lower_bound is not None and metric_sample['value'] > lower_bound
  )
  upper_bound_satisfied = (
      upper_bound is not None and metric_sample['value'] < upper_bound
  )
  if lower_bound_satisfied and upper_bound_satisfied:
    logging.info(
        'Skip teardown condition met: %s is greater than %s %s and less'
        ' than %s %s',
        metric_sample['metric'],
        lower_bound,
        metric_sample['unit'],
        upper_bound,
        metric_sample['unit'],
    )
    return True
  # Requires that a metric meet both thresholds if lower_bound < upper_bound.
  elif (
      lower_bound is not None
      and upper_bound is not None
      and lower_bound < upper_bound
  ):
    return False
  elif lower_bound_satisfied:
    logging.info(
        'Skip teardown condition met: %s is greater than %s %s',
        metric_sample['metric'],
        lower_bound,
        metric_sample['unit'],
    )
    return True
  elif upper_bound_satisfied:
    logging.info(
        'Skip teardown condition met: %s is less than %s %s',
        metric_sample['metric'],
        upper_bound,
        metric_sample['unit'],
    )
    return True
  return False


def ShouldTeardown(
    skip_teardown_conditions: Mapping[str, Mapping[str, float | None]],
    samples: MutableSequence[Mapping[str, Any]],
    vms: Sequence[virtual_machine.BaseVirtualMachine] | None = None,
    skip_teardown_zonal_vm_limit: int | None = None,
    skip_teardown_on_command_timeout: bool = False,
) -> bool:
  """Checks all samples against all skip teardown conditions.

  The checks run in this order:
    1. If there are no conditions and skip_teardown_on_command_timeout is
       False, teardown runs.
    2. If skip_teardown_on_command_timeout is True and any sample's metadata
       has failed_substatus == COMMAND_TIMEOUT, teardown is skipped.
    3. If skip_teardown_zonal_vm_limit is set and, for any VM, the number of
       VMs it reports as already left behind plus len(vms) exceeds the limit,
       teardown runs regardless of the conditions.
    4. If any sample meets a condition (see MetricMeetsConditions), teardown
       is skipped; otherwise it runs.

  Args:
    skip_teardown_conditions: mapping from metric name to a dict with
      'lower_bound'/'upper_bound' keys, as returned by
      ParseSkipTeardownConditions. May be empty or None.
    samples: list of samples to check against the conditions.
    vms: list of VMs brought up by the benchmark. None is treated as no VMs.
    skip_teardown_zonal_vm_limit: the maximum number of VMs in the zone that
      can be left behind.
    skip_teardown_on_command_timeout: a boolean indicating whether to skip
      teardown if the failure substatus is COMMAND_TIMEOUT.

  Returns:
    True if the benchmark should teardown as usual, False if it should skip due
    to a condition being met.
  """
  if not skip_teardown_conditions and not skip_teardown_on_command_timeout:
    return True
  if skip_teardown_on_command_timeout:
    for status_sample in samples:
      if (
          status_sample['metadata'].get('failed_substatus')
          == benchmark_status.FailedSubstatus.COMMAND_TIMEOUT
      ):
        logging.warning(
            'Skipping TEARDOWN phase due to COMMAND_TIMEOUT substatus.'
        )
        return False
  # `vms` defaults to None; iterating it directly would raise TypeError.
  if skip_teardown_zonal_vm_limit and vms:
    num_vms = len(vms)
    for vm in vms:
      num_lingering_vms = vm.GetNumTeardownSkippedVms()
      if (
          num_lingering_vms is not None
          and num_lingering_vms + num_vms > skip_teardown_zonal_vm_limit
      ):
        logging.warning(
            'Too many lingering VMs: tearing down resources regardless of skip'
            ' teardown conditions.'
        )
        return True
  if not skip_teardown_conditions:
    # Only the command-timeout check was requested, and it did not trigger.
    return True
  for metric_sample in samples:
    if MetricMeetsConditions(metric_sample, skip_teardown_conditions):
      logging.warning('Skipping TEARDOWN phase.')
      return False
  return True


def _InjectBenchmarkInfoIntoDocumentation():
  """Rewrites the __main__ module's docstring to include benchmark help.

  The new docstring consists of the PKB version, then this module's
  docstring, then the output of _GenerateBenchmarkDocumentation(). It ends
  with a 'Benchmark Sets' section that has one '<set name>:  <message>' line
  per entry in benchmark_sets.BENCHMARK_SETS.
  """
  # TODO: Verify if there is other way of appending additional help
  # message.
  set_lines = []
  for set_name, set_info in benchmark_sets.BENCHMARK_SETS.items():
    set_lines.append('%s:  %s' % (set_name, set_info['message']))

  header = (
      'PerfKitBenchmarker version: {version}\n\n{doc}\n'
      'Benchmarks (default requirements):\n'
      '\t{benchmark_doc}'
  ).format(
      version=version.VERSION,
      doc=__doc__,
      benchmark_doc=_GenerateBenchmarkDocumentation(),
  )
  sets_section = '\n\nBenchmark Sets:\n\t%s' % '\n\t'.join(set_lines)
  sys.modules['__main__'].__doc__ = header + sets_section


def _ParseFlags(argv):
  """Parses the command-line flags into FLAGS.

  On a parse error, logs the error plus usage hints and exits with status 1.

  Args:
    argv: the command line to parse, including the program name.

  Returns:
    The arguments that absl did not consume as flags, as returned by
    FLAGS(argv). Earlier versions returned None, and callers may still
    ignore this value.
  """
  try:
    return FLAGS(argv)
  except flags.Error as e:
    logging.error(e)
    logging.info('For usage instructions, use --helpmatch={module_name}')
    logging.info('For example, ./pkb.py --helpmatch=benchmarks.fio')
    sys.exit(1)


def _PrintHelp(matches=None):
  """Prints help for flags, optionally limited to matching modules.

  Args:
    matches: regex string or None. When falsy, help for every flag is printed.
      Otherwise only modules whose name the regex finds a match in (via
      re.search) are printed, in sorted order.
  """
  if matches:
    pattern = re.compile(matches)
    for name in sorted(FLAGS.flags_by_module_dict()):
      if pattern.search(name):
        print(FLAGS.module_help(name))
    return
  print(FLAGS)


def _PrintHelpMD(matches=None):
  """Prints markdown formatted help for flags defined in matching modules.

  Works just like --helpmatch.

  Args:
    matches: regex string or None. Filters help to only those modules whose
      name matches the regex (via re.search). If None or empty, all modules
      are printed.

  Raises:
    RuntimeError: If unable to find the module name in a module's help text.

  Examples:
    * all flags: `./pkb.py --helpmatchmd .* > testsuite_docs/all.md`
    * linux benchmarks:
      `./pkb.py --helpmatchmd linux_benchmarks.* > testsuite_docs/linux_benchmarks.md`
    * specific modules: `./pkb.py --helpmatchmd iperf > testsuite_docs/iperf.md`
    * windows packages:
      `./pkb.py --helpmatchmd windows_packages.* > testsuite_docs/windows_packages.md`
    * GCP provider:
      `./pkb.py --helpmatchmd providers.gcp.* > testsuite_docs/providers_gcp.md`
  """
  # Compile the patterns once, not once per matching module.
  module_regex = re.compile(MODULE_REGEX)
  flags_regex = re.compile(FLAGS_REGEX, re.MULTILINE | re.DOTALL)
  flagname_regex = re.compile(FLAGNAME_REGEX, re.MULTILINE | re.DOTALL)
  docstring_regex = re.compile(DOCSTRING_REGEX, re.MULTILINE | re.DOTALL)
  # re.compile(None) raises TypeError; an empty pattern matches every name.
  name_filter = re.compile(matches or '')

  for module_name in sorted(FLAGS.flags_by_module_dict()):
    if not name_filter.search(module_name):
      continue
    helptext_raw = FLAGS.module_help(module_name)

    # Converts module name to github linkable string.
    # eg: perfkitbenchmarker.linux_benchmarks.iperf_vpn_benchmark ->
    # perfkitbenchmarker/linux_benchmarks/iperf_vpn_benchmark.py
    match = module_regex.search(helptext_raw)
    if not match:
      raise RuntimeError(
          f'Unable to find "{module_regex}" in "{helptext_raw}"'
      )
    module = match.group(1)
    module_link = module.replace('.', '/') + '.py'

    # Put each flag name in a markdown code span for visibility. Named
    # flag_entries so the absl `flags` module is not shadowed.
    flag_entries = [
        flagname_regex.sub(r'`\1`\2', entry)
        for entry in flags_regex.findall(helptext_raw)
    ]

    # Get the module docstring from its source file rather than importing it.
    # Only modules found at module_link (relative to the working directory)
    # are read, and some have no docstring at all
    # (eg perfkitbenchmarker/providers/alicloud/flags.py).
    docstring = 'No description available'
    is_local_module = isfile(module_link)
    if is_local_module:
      with open(module_link, encoding='utf-8') as f:
        docstring_match = docstring_regex.search(f.read())
      if docstring_match is not None:
        docstring = docstring_match.group(1)
      # Only print links for modules we can find.
      print('### [' + module + '](' + BASE_RELATIVE + module_link + ')\n')
    else:
      print('### ' + module + '\n')
    print('#### Description:\n\n' + docstring + '\n\n#### Flags:\n')
    print('\n'.join(flag_entries) + '\n')


def CheckVersionFlag():
  """Handles --version: prints the PKB version and exits with status 0.

  Does nothing when --version was not given.
  """
  if not FLAGS.version:
    return
  print(version.VERSION)
  sys.exit(0)


def _InitializeRunUri():
  """Ensures FLAGS.run_uri is set and valid.

  If --run_uri was not given and the provision stage will run, a new id is
  generated from the last 8 characters of a random UUID4. If it was not
  given and provision will not run, the URI returned by
  vm_util.GetLastRunUri() (the last modified run directory) is reused.

  Raises:
    errors.Setup.NoRunURIError: if no run_uri was given, the provision stage
      is not being run, and no previous run URI could be found.
    errors.Setup.BadRunURIError: if the given run_uri is not alphanumeric or
      is longer than MAX_RUN_URI_LENGTH.
  """
  given_uri = FLAGS.run_uri
  if given_uri is not None:
    if given_uri.isalnum() and len(given_uri) <= MAX_RUN_URI_LENGTH:
      return
    raise errors.Setup.BadRunURIError(
        'run_uri must be alphanumeric and less '
        'than or equal to %d characters in '
        'length.' % MAX_RUN_URI_LENGTH
    )

  if stages.PROVISION in FLAGS.run_stage:
    FLAGS.run_uri = str(uuid.uuid4())[-8:]
    return

  # Attempt to get the last modified run directory.
  previous_uri = vm_util.GetLastRunUri()
  if not previous_uri:
    raise errors.Setup.NoRunURIError(
        'No run_uri specified. Could not run the following stages: %s'
        % ', '.join(FLAGS.run_stage)
    )
  FLAGS.run_uri = previous_uri
  logging.warning(
      'No run_uri specified. Attempting to run the following stages with '
      '--run_uri=%s: %s',
      FLAGS.run_uri,
      ', '.join(FLAGS.run_stage),
  )


def _CreateBenchmarkSpecs():
  """Create a list of BenchmarkSpecs for each benchmark run to be scheduled.

  For each (benchmark module, user config) pair returned by
  benchmark_sets.GetBenchmarksFromFlags(), this does three things:
    * builds the benchmark's config, with the user config's 'flags' overrides
      applied while the module's GetConfig runs;
    * runs the module's optional CheckPrerequisites hook;
    * constructs the BenchmarkSpec.

  Returns:
    A list of BenchmarkSpecs, in the order of
    benchmark_sets.GetBenchmarksFromFlags().

  Raises:
    Any exception raised by a module's CheckPrerequisites, after logging it.
  """
  specs = []
  benchmark_tuple_list = benchmark_sets.GetBenchmarksFromFlags()
  # Per-benchmark-name counters used to build unique ids (name0, name1, ...).
  benchmark_counts = collections.defaultdict(itertools.count)
  for benchmark_module, user_config in benchmark_tuple_list:
    # Construct benchmark config object.
    name = benchmark_module.BENCHMARK_NAME
    # This expected_os_type check seems rather unnecessary.
    expected_os_types = os_types.ALL
    with flag_util.OverrideFlags(FLAGS, user_config.get('flags')):
      config_dict = benchmark_module.GetConfig(user_config)
    config_spec_class = getattr(
        benchmark_module,
        'BENCHMARK_CONFIG_SPEC_CLASS',
        benchmark_config_spec.BenchmarkConfigSpec,
    )
    config = config_spec_class(
        name,
        expected_os_types=expected_os_types,
        flag_values=FLAGS,
        **config_dict,
    )

    # Assign a unique ID to each benchmark run. This differs even between two
    # runs of the same benchmark within a single PKB run.
    uid = name + str(next(benchmark_counts[name]))

    # Optional step to check flag values and verify files exist.
    check_prereqs = getattr(benchmark_module, 'CheckPrerequisites', None)
    if check_prereqs:
      try:
        with config.RedirectFlags(FLAGS):
          check_prereqs(config)
      # Exception (not a bare except) so KeyboardInterrupt/SystemExit are not
      # logged as prerequisite failures; everything is still re-raised.
      except Exception:
        logging.exception('Prerequisite check failed for %s', name)
        raise

    with config.RedirectFlags(FLAGS):
      specs.append(
          bm_spec.BenchmarkSpec.GetBenchmarkSpec(benchmark_module, config, uid)
      )

  return specs


def _WriteCompletionStatusFile(benchmark_specs, status_file):
  """Writes a completion status file.

  The file has one json object per line, with keys in this order:

  {
    "name": <benchmark name>,
    "status": <completion status>,
    "failed_substatus": <failed substatus (if present)>,
    "status_detail": <descriptive string (if present)>,
    "flags": <flags dictionary, plus "freeze"/"restore" paths when set>
  }

  Args:
    benchmark_specs: The list of BenchmarkSpecs that ran.
    status_file: The file object to write the json structures to.
  """
  for spec in benchmark_specs:
    # dict insertion order determines the key order in the json output.
    status_dict = {'name': spec.name, 'status': spec.status}
    if spec.failed_substatus:
      status_dict['failed_substatus'] = spec.failed_substatus
    if spec.status_detail:
      status_dict['status_detail'] = spec.status_detail
    # Shallow copy so the freeze/restore entries added below are not written
    # back into the spec's own config flags.
    flags_dict = copy.copy(spec.config.flags)
    # Record freeze and restore path values.
    if pkb_flags.FREEZE_PATH.value:
      flags_dict['freeze'] = pkb_flags.FREEZE_PATH.value
    if pkb_flags.RESTORE_PATH.value:
      flags_dict['restore'] = pkb_flags.RESTORE_PATH.value
    status_dict['flags'] = flags_dict
    status_file.write(json.dumps(status_dict) + '\n')


def _SetRestoreSpec(spec: bm_spec.BenchmarkSpec) -> None:
  """Sets spec.restore_spec by unpickling the file at the restore path.

  Does nothing when pkb_flags.RESTORE_PATH is unset.

  NOTE(review): pickle.load can execute arbitrary code. The restore path must
  point at a file from a trusted source.
  """
  path = pkb_flags.RESTORE_PATH.value
  if not path:
    return
  logging.info('Using restore spec at path: %s', path)
  with open(path, 'rb') as pickled:
    spec.restore_spec = pickle.load(pickled)


def _SetFreezePath(spec: bm_spec.BenchmarkSpec) -> None:
  """Copies pkb_flags.FREEZE_PATH onto spec.freeze_path when it is set."""
  freeze_path = pkb_flags.FREEZE_PATH.value
  if not freeze_path:
    return
  spec.freeze_path = freeze_path
  logging.info('Using freeze path, %s', spec.freeze_path)


def DoProvisionPhase(
    spec: bm_spec.BenchmarkSpec, timer: timing_util.IntervalTimer
):
  """Performs the Provision phase of benchmark execution.

  Constructs the spec's resources, checks its prerequisites, and provisions
  the resources. The spec is pickled both before and after provisioning. The
  before_phase/after_phase PROVISION events are sent around the work, and the
  register_tracers and benchmark_start events are sent just before
  provisioning.

  Args:
    spec: The BenchmarkSpec created for the benchmark.
    timer: An IntervalTimer that measures the start and stop times of resource
      provisioning.
  """
  logging.info('Provisioning resources for benchmark %s', spec.name)
  events.before_phase.send(stages.PROVISION, benchmark_spec=spec)
  spec.ConstructResources()

  spec.CheckPrerequisites()

  # Pickle the spec before we try to create anything so we can clean
  # everything up on a second run if something goes wrong.
  spec.Pickle()

  events.register_tracers.send(parsed_flags=FLAGS)
  events.benchmark_start.send(benchmark_spec=spec)
  try:
    with timer.Measure('Resource Provisioning'):
      spec.Provision()
  finally:
    # Also pickle the spec after the resources are created so that
    # we have a record of things like AWS ids. Otherwise we won't
    # be able to clean them up on a subsequent run.
    # This runs even when Provision raises, so partial state is recorded.
    spec.Pickle()
  # Not reached if Provision raised.
  events.after_phase.send(stages.PROVISION, benchmark_spec=spec)


class InterruptChecker:
  """Watches interruptible VMs for interruption on background threads.

  One polling thread is started per VM whose IsInterruptible() returns True.
  Each thread polls until its VM reports it was interrupted, or until the
  checker is stopped by EndCheckInterruptThread or
  EndCheckInterruptThreadAndRaiseError.
  """

  def __init__(self, vms):
    """Starts one interrupt-polling thread per interruptible VM.

    Args:
      vms: A list of virtual machines.
    """
    self.vms = vms
    self.check_threads = []
    # Setting this event tells every polling thread to stop.
    self.phase_status = threading.Event()
    for watched_vm in (vm for vm in vms if vm.IsInterruptible()):
      poller = threading.Thread(target=self.CheckInterrupt, args=(watched_vm,))
      poller.start()
      self.check_threads.append(poller)

  def CheckInterrupt(self, vm):
    """Polls one VM until it is interrupted or polling is stopped.

    Between polls, waits vm.GetInterruptableStatusPollSeconds() seconds and
    wakes early if phase_status is set.

    Args:
      vm: the virtual machine object.
    """
    while not self.phase_status.is_set():
      vm.UpdateInterruptibleVmStatus(use_api=False)
      if vm.WasInterrupted():
        break
      self.phase_status.wait(vm.GetInterruptableStatusPollSeconds())

  def EndCheckInterruptThread(self):
    """Signals all polling threads to stop and waits for each to exit."""
    self.phase_status.set()
    for poller in self.check_threads:
      poller.join()

  def EndCheckInterruptThreadAndRaiseError(self):
    """Stops polling, then raises if any interruptible VM was interrupted.

    Raises:
      errors.Benchmarks.InsufficientCapacityCloudFailure: if any interruptible
        VM reports that it was interrupted.
    """
    self.EndCheckInterruptThread()
    was_interrupted = any(
        vm.IsInterruptible() and vm.WasInterrupted() for vm in self.vms
    )
    if was_interrupted:
      raise errors.Benchmarks.InsufficientCapacityCloudFailure('Interrupt')


def DoPreparePhase(
    spec: bm_spec.BenchmarkSpec, timer: timing_util.IntervalTimer
):
  """Performs the Prepare phase of benchmark execution.

  Runs the spec's own Prepare, then the benchmark module's Prepare (both
  timed), starts the spec's background workload, and then sleeps for
  FLAGS.after_prepare_sleep_time seconds if that flag is set. The
  before_phase/after_phase PREPARE events are sent around the work.

  Args:
    spec: The BenchmarkSpec created for the benchmark.
    timer: An IntervalTimer that measures the start and stop times of the
      benchmark module's Prepare function.
  """
  logging.info('Preparing benchmark %s', spec.name)
  events.before_phase.send(stages.PREPARE, benchmark_spec=spec)
  with timer.Measure('BenchmarkSpec Prepare'):
    spec.Prepare()
  with timer.Measure('Benchmark Prepare'):
    spec.BenchmarkPrepare(spec)
  spec.StartBackgroundWorkload()
  if FLAGS.after_prepare_sleep_time:
    logging.info(
        'Sleeping for %s seconds after the prepare phase.',
        FLAGS.after_prepare_sleep_time,
    )
    time.sleep(FLAGS.after_prepare_sleep_time)
  events.after_phase.send(stages.PREPARE, benchmark_spec=spec)


def DoRunPhase(
    spec: bm_spec.BenchmarkSpec,
    collector: publisher.SampleCollector,
    timer: timing_util.IntervalTimer,
):
  """Performs the Run phase of benchmark execution.

  Calls spec.BenchmarkRun repeatedly until the run stage is finished. If
  FLAGS.run_stage_time > 0, that happens once that many seconds have passed
  since the deadline was computed near the start of this function.
  Otherwise it happens after FLAGS.run_stage_iterations iterations.

  If an iteration's BenchmarkRun raises, the error is logged and the run is
  retried. Once the number of consecutive failures exceeds
  FLAGS.run_stage_retries, the exception propagates instead.

  Each iteration's samples are added to the collector, together with any
  extra samples enabled by flags (boot time, GPU, lscpu, ulimit, proccpu,
  CPU vulnerabilities, gcc, glibc).

  Args:
    spec: The BenchmarkSpec created for the benchmark.
    collector: The SampleCollector object to add samples to.
    timer: An IntervalTimer that measures the start and stop times of the
      benchmark module's Run function.
  """
  if FLAGS.before_run_pause:
    input('Hit enter to begin Run.')
  # Only consulted when FLAGS.run_stage_time > 0.
  deadline = time.time() + FLAGS.run_stage_time
  run_number = 0
  consecutive_failures = 0
  last_publish_time = time.time()

  def _IsRunStageFinished():
    if FLAGS.run_stage_time > 0:
      return time.time() > deadline
    else:
      return run_number >= FLAGS.run_stage_iterations

  while True:
    # Reset each iteration so a failed (retried) run contributes no benchmark
    # samples; the flag-driven extra samples below are still added.
    samples = []
    logging.info('Running benchmark %s', spec.name)
    events.before_phase.send(stages.RUN, benchmark_spec=spec)
    events.trigger_phase.send()
    try:
      with timer.Measure('Benchmark Run'):
        samples = spec.BenchmarkRun(spec)
    except Exception:
      consecutive_failures += 1
      if consecutive_failures > FLAGS.run_stage_retries:
        raise
      logging.exception(
          'Run failed (consecutive_failures=%s); retrying.',
          consecutive_failures,
      )
    else:
      consecutive_failures = 0
    finally:
      events.after_phase.send(stages.RUN, benchmark_spec=spec)
    if FLAGS.run_stage_time or FLAGS.run_stage_iterations:
      for s in samples:
        s.metadata['run_number'] = run_number

    # Add boot time metrics on the first run iteration.
    if run_number == 0 and (
        FLAGS.boot_samples or spec.name == cluster_boot_benchmark.BENCHMARK_NAME
    ):
      samples.extend(cluster_boot_benchmark.GetTimeToBoot(spec.vms))

    # In order to collect GPU samples one of the VMs must have both an Nvidia
    # GPU and the nvidia-smi
    if FLAGS.gpu_samples:
      samples.extend(cuda_memcpy_benchmark.Run(spec))

    if FLAGS.record_lscpu:
      samples.extend(linux_virtual_machine.CreateLscpuSamples(spec.vms))
    if FLAGS.record_ulimit:
      samples.extend(linux_virtual_machine.CreateUlimitSamples(spec.vms))

    if pkb_flags.RECORD_PROCCPU.value:
      samples.extend(linux_virtual_machine.CreateProcCpuSamples(spec.vms))
    if FLAGS.record_cpu_vuln and run_number == 0:
      samples.extend(_CreateCpuVulnerabilitySamples(spec.vms))

    if FLAGS.record_gcc:
      samples.extend(_CreateGccSamples(spec.vms))
    if FLAGS.record_glibc:
      samples.extend(_CreateGlibcSamples(spec.vms))

    # Mark samples as restored to differentiate from non freeze/restore runs.
    if FLAGS.restore:
      for s in samples:
        s.metadata['restore'] = True

    events.benchmark_samples_created.send(benchmark_spec=spec, samples=samples)
    events.all_samples_created.send(benchmark_spec=spec, samples=samples)
    collector.AddSamples(samples, spec.name, spec)
    # Publish mid-run once publish_period seconds have passed since the last
    # publish, when --publish_after_run and --publish_period are both set.
    if (
        FLAGS.publish_after_run
        and FLAGS.publish_period is not None
        and FLAGS.publish_period < (time.time() - last_publish_time)
    ):
      collector.PublishSamples()
      last_publish_time = time.time()

    # NOTE: this sleep also happens after the final iteration, because it
    # runs before the stage-finished check below.
    if pkb_flags.BETWEEN_RUNS_SLEEP_TIME.value > 0:
      logging.info(
          'Sleeping for %s seconds after run %d.',
          FLAGS.between_runs_sleep_time,
          run_number,
      )
      time.sleep(FLAGS.between_runs_sleep_time)

    # Failed-and-retried iterations count toward run_number as well.
    run_number += 1
    if _IsRunStageFinished():
      if FLAGS.after_run_sleep_time:
        logging.info(
            'Sleeping for %s seconds after the run phase.',
            FLAGS.after_run_sleep_time,
        )
        time.sleep(FLAGS.after_run_sleep_time)
      break


def DoCleanupPhase(
    spec: bm_spec.BenchmarkSpec, timer: timing_util.IntervalTimer
):
  """Performs the Cleanup phase of benchmark execution.

  Cleanup phase work should be delegated to spec.BenchmarkCleanup to allow
  non-PKB based cleanup if needed.

  Stopping the background workload and the benchmark's cleanup run only in
  these cases: spec.always_call_cleanup is set, any VM is static, or the spec
  has a dpb_service. The before_phase/after_phase CLEANUP events are always
  sent.

  Args:
    spec: The BenchmarkSpec created for the benchmark.
    timer: An IntervalTimer that measures the start and stop times of the
      benchmark module's Cleanup function.
  """
  if FLAGS.before_cleanup_pause:
    input('Hit enter to begin Cleanup.')
  logging.info('Cleaning up benchmark %s', spec.name)
  events.before_phase.send(stages.CLEANUP, benchmark_spec=spec)
  if (
      spec.always_call_cleanup
      # Generator rather than a list: any() can stop at the first static VM.
      or any(vm.is_static for vm in spec.vms)
      or spec.dpb_service is not None
  ):
    spec.StopBackgroundWorkload()
    with timer.Measure('Benchmark Cleanup'):
      spec.BenchmarkCleanup(spec)
  events.after_phase.send(stages.CLEANUP, benchmark_spec=spec)


def DoTeardownPhase(
    spec: bm_spec.BenchmarkSpec,
    collector: publisher.SampleCollector,
    timer: timing_util.IntervalTimer,
):
  """Performs the Teardown phase of benchmark execution.

  Teardown phase work should be delegated to spec.Delete to allow non-PKB based
  teardown if needed.

  When pkb_flags.MEASURE_DELETE is set, delete-time samples from
  cluster_boot_benchmark.MeasureDelete are added to the collector after
  deletion. The before_phase/after_phase TEARDOWN events are sent around the
  work.

  Args:
    spec: The BenchmarkSpec created for the benchmark.
    collector: The SampleCollector object to add samples to (if collecting
      delete samples)
    timer: An IntervalTimer that measures the start and stop times of resource
      teardown.
  """
  logging.info('Tearing down resources for benchmark %s', spec.name)
  events.before_phase.send(stages.TEARDOWN, benchmark_spec=spec)

  with timer.Measure('Resource Teardown'):
    spec.Delete()

  # Add delete time metrics after metadata collected
  if pkb_flags.MEASURE_DELETE.value:
    samples = cluster_boot_benchmark.MeasureDelete(spec.vms)
    collector.AddSamples(samples, spec.name, spec)

  events.after_phase.send(stages.TEARDOWN, benchmark_spec=spec)


def _SkipPendingRunsFile():
  """Returns True when the --skip_pending_runs_file sentinel file exists.

  Intended to be registered as a skip-pending-runs check; a warning is logged
  whenever it decides a benchmark should be skipped.
  """
  sentinel_path = FLAGS.skip_pending_runs_file
  if not sentinel_path or not isfile(sentinel_path):
    return False
  logging.warning('%s exists.  Skipping benchmark.', sentinel_path)
  return True


_SKIP_PENDING_RUNS_CHECKS = []


def RegisterSkipPendingRunsCheck(func):
  """Registers a function to skip pending runs.

  Args:
    func: A function which returns True if pending runs should be skipped.
  """
  _SKIP_PENDING_RUNS_CHECKS.append(func)


@events.before_phase.connect
def _PublishStageStartedSamples(
    sender: str, benchmark_spec: bm_spec.BenchmarkSpec
):
  """Emits "started" samples when a run stage begins.

  Connected to events.before_phase, so `sender` is the stage name. The
  provision stage additionally emits the "Run Started" sample when
  --create_started_run_sample is set; every stage emits a
  "<Stage> Stage Started" sample when --create_started_stage_samples is set.
  """
  is_provision = sender == stages.PROVISION
  if is_provision and pkb_flags.CREATE_STARTED_RUN_SAMPLE.value:
    _PublishRunStartedSample(benchmark_spec)
  if not pkb_flags.CREATE_STARTED_STAGE_SAMPLES.value:
    return
  _PublishEventSample(benchmark_spec, f'{sender.capitalize()} Stage Started')


def _PublishRunStartedSample(spec):
  """Publishes a 'Run Started' sample right away.

  Publishing immediately guarantees that every run leaves at least one metric
  behind, even if the process dies later. The metadata carries the provided
  command-line flags and, when the corresponding log buckets and a run_uri are
  configured, the cloud paths where the PKB and VM logs will be uploaded.

  Args:
    spec: The BenchmarkSpec object with run information.
  """
  metadata = {'flags': str(flag_util.GetProvidedCommandLineFlags())}

  pkb_bucket = log_util.PKB_LOG_BUCKET.value
  if pkb_bucket and FLAGS.run_uri:
    metadata['pkb_log_path'] = log_util.GetLogCloudPath(
        pkb_bucket, f'{FLAGS.run_uri}-pkb.log'
    )

  vm_bucket = log_util.VM_LOG_BUCKET.value
  if vm_bucket and FLAGS.run_uri:
    metadata['vm_log_path'] = log_util.GetLogCloudPath(
        vm_bucket, FLAGS.run_uri
    )

  _PublishEventSample(spec, 'Run Started', metadata)


def _PublishEventSample(
    spec: bm_spec.BenchmarkSpec,
    event: str,
    metadata: Dict[str, Any] | None = None,
    collector: publisher.SampleCollector | None = None,
):
  """Immediately publishes a sample marking a benchmark progress event.

  The sample's value is the current time in unix seconds (unit 'seconds').

  Args:
    spec: The BenchmarkSpec object with run information.
    event: The progress event name; used as the sample metric.
    metadata: Optional metadata attached to the sample (empty if omitted).
    collector: The SampleCollector to add to and publish from; a new one is
      created when none (or a falsy one) is supplied.
  """
  # NOTE(review): SampleCollector does hold state (samples / published_samples
  # are read elsewhere in this module), so when a caller passes its own
  # collector, PublishSamples() presumably also flushes whatever it had
  # already accumulated - confirm this is intended.
  collector = collector or publisher.SampleCollector()
  event_sample = sample.Sample(event, time.time(), 'seconds', metadata or {})
  collector.AddSamples([event_sample], spec.name, spec)
  collector.PublishSamples()


def _IsException(e: Exception, exception_class: Type[Exception]) -> bool:
  """Checks if the exception is of the class or contains the class name.

  When exceptions happen on on background theads (e.g. CreationInternalError on
  CreateAndBootVm) they are not propogated as exceptions to the caller, instead
  they are propagated as text inside a wrapper exception such as
  errors.VmUtil.ThreadException.

  Args:
    e: The exception instance to inspect.
    exception_class: The exception class to check if e is an instance of.

  Returns:
     true if the exception is of the class or contains the class name.
  """
  if isinstance(e, exception_class):
    return True

  if str(exception_class.__name__) in str(e):
    return True

  return False


def RunBenchmark(
    spec: bm_spec.BenchmarkSpec, collector: publisher.SampleCollector
):
  """Runs a single benchmark and adds the results to the collector.

  Executes whichever of the provision, prepare, run, cleanup and teardown
  stages are listed in FLAGS.run_stage, in that order. Returns immediately,
  without touching spec.status, if any registered skip-pending-runs check
  returns True. Otherwise spec.status is set to FAILED up front and only
  becomes SUCCEEDED once every requested stage has completed without raising.
  On failure, spec.failed_substatus and spec.status_detail are filled in from
  the exception before it is re-raised.

  Args:
    spec: The BenchmarkSpec object with run information.
    collector: The SampleCollector object to add samples to.

  Raises:
    Any Exception or KeyboardInterrupt raised by a stage is re-raised after the
    failure has been recorded (and, depending on flags, after cleanup and/or
    teardown have been attempted).
  """

  # Since there are issues with the handling SIGINT/KeyboardInterrupt (see
  # further discussion in _BackgroundProcessTaskManager) this mechanism is
  # provided for defense in depth to force skip pending runs after SIGINT.
  for f in _SKIP_PENDING_RUNS_CHECKS:
    if f():
      logging.warning('Skipping benchmark.')
      return

  # Assume failure until the very last statement of this function, which only
  # runs when no stage raised.
  spec.status = benchmark_status.FAILED
  # Updated as each stage begins so a failure sample can name the failed stage.
  current_run_stage = stages.PROVISION

  # If the skip_teardown_conditions flag is set, we will check the samples
  # collected before the teardown phase to determine if we should skip teardown.
  should_teardown = True

  # Modify the logger prompt for messages logged within this function.
  label_extension = '{}({}/{})'.format(
      spec.name, spec.sequence_number, spec.total_benchmarks
  )
  context.SetThreadBenchmarkSpec(spec)
  log_context = log_util.GetThreadLogContext()
  with log_context.ExtendLabel(label_extension):
    with spec.RedirectGlobalFlags():
      end_to_end_timer = timing_util.IntervalTimer()
      detailed_timer = timing_util.IntervalTimer()
      # Non-None only while a stage's InterruptChecker is active; the finally
      # block stops it if that stage raised.
      interrupt_checker = None
      try:
        with end_to_end_timer.Measure('End to End'):
          _SetRestoreSpec(spec)
          _SetFreezePath(spec)

          if stages.PROVISION in FLAGS.run_stage:
            DoProvisionPhase(spec, detailed_timer)

          if stages.PREPARE in FLAGS.run_stage:
            current_run_stage = stages.PREPARE
            interrupt_checker = InterruptChecker(spec.vms)
            DoPreparePhase(spec, detailed_timer)
            interrupt_checker.EndCheckInterruptThreadAndRaiseError()
            interrupt_checker = None

          if stages.RUN in FLAGS.run_stage:
            current_run_stage = stages.RUN
            interrupt_checker = InterruptChecker(spec.vms)
            DoRunPhase(spec, collector, detailed_timer)
            interrupt_checker.EndCheckInterruptThreadAndRaiseError()
            interrupt_checker = None

          if stages.CLEANUP in FLAGS.run_stage:
            current_run_stage = stages.CLEANUP
            interrupt_checker = InterruptChecker(spec.vms)
            DoCleanupPhase(spec, detailed_timer)
            interrupt_checker.EndCheckInterruptThreadAndRaiseError()
            interrupt_checker = None

          if stages.TEARDOWN in FLAGS.run_stage:
            CaptureVMLogs(spec.vms)
            skip_teardown_conditions = ParseSkipTeardownConditions(
                pkb_flags.SKIP_TEARDOWN_CONDITIONS.value
            )
            # Both already-published and still-pending samples are considered.
            should_teardown = ShouldTeardown(
                skip_teardown_conditions,
                collector.published_samples + collector.samples,
                spec.vms,
                pkb_flags.SKIP_TEARDOWN_ZONAL_VM_LIMIT.value,
                pkb_flags.SKIP_TEARDOWN_ON_COMMAND_TIMEOUT.value,
            )
            if should_teardown:
              current_run_stage = stages.TEARDOWN
              DoTeardownPhase(spec, collector, detailed_timer)
            else:
              # Teardown was skipped, so resources are left up; refresh each
              # VM's timeout metadata instead.
              for vm in spec.vms:
                vm.UpdateTimeoutMetadata()

        # Add timing samples.
        if (
            FLAGS.run_stage == stages.STAGES
            and timing_util.EndToEndRuntimeMeasurementEnabled()
        ):
          collector.AddSamples(
              end_to_end_timer.GenerateSamples(), spec.name, spec
          )
        if timing_util.RuntimeMeasurementsEnabled():
          collector.AddSamples(
              detailed_timer.GenerateSamples(), spec.name, spec
          )

        # Add resource related samples.
        collector.AddSamples(spec.GetSamples(), spec.name, spec)
      # except block will clean up benchmark specific resources on exception. It
      # may also clean up generic resources based on
      # FLAGS.always_teardown_on_exception.
      except (Exception, KeyboardInterrupt) as e:
        # Log specific type of failure, if known
        # TODO(dlott) Move to exception chaining with Python3 support
        # Order matters: _IsException also matches on the class name appearing
        # in the message text, so the first matching category wins.
        if _IsException(e, errors.Benchmarks.InsufficientCapacityCloudFailure):
          spec.failed_substatus = (
              benchmark_status.FailedSubstatus.INSUFFICIENT_CAPACITY
          )
        elif _IsException(e, errors.Benchmarks.QuotaFailure):
          spec.failed_substatus = benchmark_status.FailedSubstatus.QUOTA
        elif (
            _IsException(e, errors.Benchmarks.KnownIntermittentError)
            or _IsException(e, errors.Resource.CreationInternalError)
            or _IsException(e, errors.Resource.ProvisionTimeoutError)
        ):
          spec.failed_substatus = (
              benchmark_status.FailedSubstatus.KNOWN_INTERMITTENT
          )
        elif _IsException(e, errors.Benchmarks.UnsupportedConfigError):
          spec.failed_substatus = benchmark_status.FailedSubstatus.UNSUPPORTED
        elif _IsException(e, errors.Resource.RestoreError):
          spec.failed_substatus = (
              benchmark_status.FailedSubstatus.RESTORE_FAILED
          )
        elif _IsException(e, errors.Resource.FreezeError):
          spec.failed_substatus = benchmark_status.FailedSubstatus.FREEZE_FAILED
        elif isinstance(e, KeyboardInterrupt):
          spec.failed_substatus = (
              benchmark_status.FailedSubstatus.PROCESS_KILLED
          )
        elif _IsException(e, vm_util.TimeoutExceededRetryError):
          spec.failed_substatus = (
              benchmark_status.FailedSubstatus.COMMAND_TIMEOUT
          )
        elif _IsException(e, vm_util.RetriesExceededRetryError):
          spec.failed_substatus = (
              benchmark_status.FailedSubstatus.RETRIES_EXCEEDED
          )
        elif _IsException(e, errors.Config.InvalidValue):
          spec.failed_substatus = benchmark_status.FailedSubstatus.INVALID_VALUE
        elif _IsException(e, vm_util.ImageNotFoundError):
          spec.failed_substatus = benchmark_status.FailedSubstatus.UNSUPPORTED
        else:
          spec.failed_substatus = benchmark_status.FailedSubstatus.UNCATEGORIZED
        spec.status_detail = str(e)

        # Resource cleanup (below) can take a long time. Log the error to give
        # immediate feedback, then re-throw.
        logging.exception('Error during benchmark %s', spec.name)
        if FLAGS.create_failed_run_samples:
          PublishFailedRunSample(spec, str(e), current_run_stage, collector)

        # If the particular benchmark requests us to always call cleanup, do it
        # here.
        if stages.CLEANUP in FLAGS.run_stage and spec.always_call_cleanup:
          DoCleanupPhase(spec, detailed_timer)

        if (
            FLAGS.always_teardown_on_exception
            and stages.TEARDOWN not in FLAGS.run_stage
        ):
          # Note that if TEARDOWN is specified, it will happen below.
          DoTeardownPhase(spec, collector, detailed_timer)
        raise
      # finally block will only clean up generic resources if teardown is
      # included in FLAGS.run_stage.
      finally:
        if interrupt_checker:
          interrupt_checker.EndCheckInterruptThread()
        # Deleting resources should happen first so any errors with publishing
        # don't prevent teardown.
        # NOTE(review): on the success path DoTeardownPhase has already called
        # spec.Delete(), so this is a second call; presumably Delete() is
        # idempotent - confirm.
        if stages.TEARDOWN in FLAGS.run_stage and should_teardown:
          spec.Delete()
        if FLAGS.publish_after_run:
          collector.PublishSamples()
        events.benchmark_end.send(benchmark_spec=spec)
        # Pickle spec to save final resource state.
        spec.Pickle()
  # Only reached when no stage raised (exceptions are re-raised above).
  spec.status = benchmark_status.SUCCEEDED


def PublishFailedRunSample(
    spec: bm_spec.BenchmarkSpec,
    error_message: str,
    run_stage_that_failed: str,
    collector: publisher.SampleCollector,
):
  """Publishes a 'Run Failed' sample describing a failed run stage.

  The sample's value is the current unix timestamp (unit 'seconds'). Its
  metadata holds the (truncated) error message, the failed stage, and every
  provided PKB command-line flag. Each VM's interruptible status is refreshed
  first; if any interruptible VM was interrupted, spec.failed_substatus is
  overwritten with INTERRUPTED and interruption counts and status codes are
  added to the metadata.

  Args:
    spec: benchmark_spec
    error_message: error message that was caught, resulting in the run stage
      failure.
    run_stage_that_failed: run stage that failed by raising an Exception
    collector: the collector to publish to.
  """
  # All provided PKB flags are attached, not just those belonging to the
  # failed benchmark; gflags' FlagsByModuleDict() could narrow this down.
  max_error_length = FLAGS.failed_run_samples_error_length
  metadata = {
      'error_message': error_message[:max_error_length],
      'run_stage': run_stage_that_failed,
      'flags': str(flag_util.GetProvidedCommandLineFlags()),
  }
  background_tasks.RunThreaded(
      lambda vm: vm.UpdateInterruptibleVmStatus(use_api=True), spec.vms
  )

  interruptible_vm_count = 0
  interrupted_vm_count = 0
  vm_status_codes = []
  for vm in spec.vms:
    if not vm.IsInterruptible():
      continue
    interruptible_vm_count += 1
    if not vm.WasInterrupted():
      continue
    interrupted_vm_count += 1
    # An interruption takes precedence over the exception-derived substatus.
    spec.failed_substatus = benchmark_status.FailedSubstatus.INTERRUPTED
    status_code = vm.GetVmStatusCode()
    if status_code:
      vm_status_codes.append(status_code)

  if spec.failed_substatus:
    metadata['failed_substatus'] = spec.failed_substatus

  if interruptible_vm_count:
    metadata['interruptible_vms'] = interruptible_vm_count
    metadata['interrupted_vms'] = interrupted_vm_count
    metadata['vm_status_codes'] = vm_status_codes
  if interrupted_vm_count:
    logging.error(
        '%d interruptible VMs were interrupted in this failed PKB run.',
        interrupted_vm_count,
    )
  _PublishEventSample(spec, 'Run Failed', metadata, collector)


def _ShouldRetry(spec: bm_spec.BenchmarkSpec) -> bool:
  """Returns True if the run failed with a substatus listed as retryable."""
  if spec.status != benchmark_status.FAILED:
    return False
  return spec.failed_substatus in _RETRY_SUBSTATUSES.value


def _GetMachineTypes(spec: bm_spec.BenchmarkSpec) -> list[str]:
  """Returns a deduped, sorted list of machine types to provision for the spec.

  An explicit --machine_type overrides the per-VM-group machine types.
  """
  if FLAGS.machine_type:
    return [FLAGS.machine_type]
  # sorted() already returns a new list, so no list() wrapper is needed.
  return sorted({
      vm_group_spec.vm_spec.machine_type
      for vm_group_spec in spec.vms_to_boot.values()
  })


def RunBenchmarkTask(
    spec: bm_spec.BenchmarkSpec,
) -> Tuple[Sequence[bm_spec.BenchmarkSpec], List[sample.SampleDict]]:
  """Task that executes RunBenchmark.

  This is designed to be used with RunParallelProcesses. Note that
  for retries only the last run has its samples published.

  Arguments:
    spec: BenchmarkSpec. The spec to call RunBenchmark with.

  Returns:
    A BenchmarkSpec for each run iteration and a list of samples from the
    last run. If no iteration ran, the original spec and an empty list.
  """
  # Many providers name resources using run_uris. When running multiple
  # benchmarks in parallel, this causes name collisions on resources.
  # By modifying the run_uri, we avoid the collisions.
  if FLAGS.run_processes and FLAGS.run_processes > 1:
    spec.config.flags['run_uri'] = FLAGS.run_uri + str(spec.sequence_number)
    # Unset run_uri so the config value takes precedence.
    FLAGS['run_uri'].present = 0

  zone_retry_manager = ZoneRetryManager(_GetMachineTypes(spec))
  # Set the run count.
  max_run_count = 1 + pkb_flags.MAX_RETRIES.value

  # Useful format string for debugging.
  benchmark_info = (
      f'{spec.sequence_number}/{spec.total_benchmarks} '
      f'{spec.name} (UID: {spec.uid})'
  )

  result_specs = []
  # Bound before the loop so the teardown check and the final return never
  # read an unassigned name (e.g. when no attempt runs at all).
  collector = None
  for current_run_count in range(max_run_count):
    # Attempt to return the most recent results.
    if _TEARDOWN_EVENT.is_set():
      if result_specs and collector:
        return result_specs, collector.samples
      return [spec], []

    run_start_msg = (
        '\n'
        + '-' * 85
        + '\n'
        + 'Starting benchmark %s attempt %s of %s'
        + '\n'
        + '-' * 85
    )
    logging.info(
        run_start_msg, benchmark_info, current_run_count + 1, max_run_count
    )
    collector = publisher.SampleCollector()
    # Make a new copy of the benchmark_spec for each run since currently a
    # benchmark spec isn't compatible with multiple runs. In particular, the
    # benchmark_spec doesn't correctly allow for a provision of resources
    # after tearing down.
    spec_for_run = copy.deepcopy(spec)
    result_specs.append(spec_for_run)
    try:
      RunBenchmark(spec_for_run, collector)
    except BaseException as e:  # pylint: disable=broad-except
      logging.exception('Exception running benchmark')
      msg = f'Benchmark {benchmark_info} failed.'
      if isinstance(e, KeyboardInterrupt) or FLAGS.stop_after_benchmark_failure:
        logging.error('%s Execution will not continue.', msg)
        _TEARDOWN_EVENT.set()
        break
      logging.error('%s Execution will continue.', msg)

    # Don't retry on the last run.
    if _ShouldRetry(spec_for_run) and current_run_count != max_run_count - 1:
      logging.info(
          'Benchmark should be retried. Waiting %s seconds before running.',
          pkb_flags.RETRY_DELAY_SECONDS.value,
      )
      time.sleep(pkb_flags.RETRY_DELAY_SECONDS.value)

      # Handle smart retries if specified.
      zone_retry_manager.HandleSmartRetries(spec_for_run)

    else:
      logging.info(
          'Benchmark should not be retried. Finished %s runs of %s',
          current_run_count + 1,
          max_run_count,
      )
      break

  if collector is None:
    # No attempt ran; report the unrun spec like the teardown path does so
    # callers can always take the last element of the returned specs.
    return [spec], []
  # We need to return both the spec and samples so that we know
  # the status of the test and can publish any samples that
  # haven't yet been published.
  return result_specs, collector.samples


class ZoneRetryManager:
  """Encapsulates state and functions for zone retries.

  When smart capacity/quota retries are enabled (explicitly, or implicitly by
  passing zone=any / zones=any), this rewrites the zone flag between benchmark
  attempts so retries land in zones that support every requested machine type
  and have not been tried yet.

  Attributes:
    _original_zone: The zone the benchmark started with (after resolving
      zone=any to a concrete zone).
    _original_region: The region containing _original_zone.
    _supported_zones: Zones that support all requested machine types.
    _zones_tried: Zones that have already been tried in previous runs.
    _regions_tried: Regions that have already been tried in previous runs.
  """

  def __init__(self, machine_types: Collection[str]):
    self._CheckFlag(machine_types)
    if (
        not pkb_flags.SMART_CAPACITY_RETRY.value
        and not pkb_flags.SMART_QUOTA_RETRY.value
    ):
      # Nothing to manage: HandleSmartRetries is a no-op with both flags off.
      return
    self._machine_types = list(machine_types)
    self._zones_tried: Set[str] = set()
    self._regions_tried: Set[str] = set()
    self._utils: types.ModuleType = providers.LoadProviderUtils(FLAGS.cloud)
    self._SetOriginalZoneAndFlag()

  def _CheckMachineTypesAreSpecified(
      self, machine_types: Collection[str]
  ) -> None:
    if not machine_types:
      raise errors.Config.MissingOption(
          'machine_type flag must be specified on the command line '
          'if zone=any feature is used.'
      )

  def _GetCurrentZoneFlag(self):
    """Returns the first value of whichever zone flag is in use."""
    return FLAGS[self._zone_flag].value[0]

  def _CheckFlag(self, machine_types: Collection[str]) -> None:
    """Records the zone flag in use; zone=any turns on both smart retries."""
    for zone_flag in ['zone', 'zones']:
      if FLAGS[zone_flag].value:
        self._zone_flag = zone_flag
        if self._GetCurrentZoneFlag() == _ANY_ZONE:
          self._CheckMachineTypesAreSpecified(machine_types)
          FLAGS['smart_capacity_retry'].parse(True)
          FLAGS['smart_quota_retry'].parse(True)

  def _SetOriginalZoneAndFlag(self) -> None:
    """Records the flag name and zone value that the benchmark started with.

    Raises:
      errors.Config.InvalidValue: if zone=any was requested but no zone
        supports every requested machine type.
    """
    # This is guaranteed to set values due to flag validator.
    # Copy so intersection_update below never mutates a set owned (and
    # possibly cached) by the provider utils module.
    self._supported_zones = set(
        self._utils.GetZonesFromMachineType(self._machine_types[0])
    )
    for machine_type in self._machine_types[1:]:
      self._supported_zones.intersection_update(
          self._utils.GetZonesFromMachineType(machine_type)
      )
    if self._GetCurrentZoneFlag() == _ANY_ZONE:
      # Without a candidate zone there is no original zone to fall back to.
      if not self._supported_zones:
        raise errors.Config.InvalidValue(
            'zone=any was requested, but no zone supports all of the machine '
            f'types {self._machine_types}.'
        )
      if pkb_flags.MAX_RETRIES.value < 1:
        FLAGS['retries'].parse(len(self._supported_zones))
      self._ChooseAndSetNewZone(self._supported_zones)
    self._original_zone = self._GetCurrentZoneFlag()
    self._original_region = self._utils.GetRegionFromZone(self._original_zone)

  def HandleSmartRetries(self, spec: bm_spec.BenchmarkSpec) -> None:
    """Handles smart zone retry flags if provided.

    If quota retry, pick zone in new region. If unsupported or stockout retries,
    pick zone in same region.

    Args:
      spec: benchmark spec.
    """
    if (
        pkb_flags.SMART_QUOTA_RETRY.value
        and spec.failed_substatus == benchmark_status.FailedSubstatus.QUOTA
    ):
      self._AssignZoneToNewRegion()
    elif pkb_flags.SMART_CAPACITY_RETRY.value and spec.failed_substatus in {
        benchmark_status.FailedSubstatus.UNSUPPORTED,
        benchmark_status.FailedSubstatus.INSUFFICIENT_CAPACITY,
    }:
      self._AssignZoneToSameRegion()

  def _AssignZoneToNewRegion(self) -> None:
    """Changes zone to be a supported zone in a different region."""
    region = self._utils.GetRegionFromZone(self._GetCurrentZoneFlag())
    self._regions_tried.add(region)
    regions_to_try = (
        {
            self._utils.GetRegionFromZone(zone)
            for zone in self._supported_zones
        }
        - self._regions_tried
    )
    # Restart from empty if we've exhausted all alternatives.
    if not regions_to_try:
      self._regions_tried.clear()
      new_region = self._original_region
    else:
      new_region = random.choice(tuple(regions_to_try))
    logging.info('Retry using new region %s', new_region)
    # Only consider zones in the region that support every machine type,
    # matching _AssignZoneToSameRegion.
    supported_zones_in_region = self._utils.GetZonesInRegion(
        new_region
    ).intersection(self._supported_zones)
    self._ChooseAndSetNewZone(supported_zones_in_region)

  def _AssignZoneToSameRegion(self) -> None:
    """Changes zone to be a new zone in the same region."""
    supported_zones_in_region = self._utils.GetZonesInRegion(
        self._original_region
    ).intersection(self._supported_zones)
    self._ChooseAndSetNewZone(supported_zones_in_region)

  def _ChooseAndSetNewZone(self, possible_zones: Set[str]) -> None:
    """Saves the current _zone_flag and sets it to a new zone.

    Falls back to the original zone (and forgets the tried zones) once every
    candidate has been tried.

    Args:
      possible_zones: The set of zones to choose from.
    """
    current_zone = self._GetCurrentZoneFlag()
    if current_zone != _ANY_ZONE:
      self._zones_tried.add(current_zone)
    zones_to_try = possible_zones - self._zones_tried
    # Restart from empty if we've exhausted all alternatives.
    if not zones_to_try:
      self._zones_tried.clear()
      new_zone = self._original_zone
    else:
      new_zone = random.choice(tuple(zones_to_try))
    logging.info('Retry using new zone %s', new_zone)
    FLAGS[self._zone_flag].unparse()
    FLAGS[self._zone_flag].parse([new_zone])


def _LogCommandLineFlags():
  """Logs the serialized form of every flag whose `present` count is nonzero."""
  all_flags = [FLAGS[flag_name] for flag_name in FLAGS]
  serialized = [flag.serialize() for flag in all_flags if flag.present]
  logging.info('Flag values:\n%s', '\n'.join(serialized))


def SetUpPKB():
  """Initializes globals, logging and the local environment for PKB.

  Once this returns, PKB functions such as benchmark_spec.Prepare() or
  benchmark_spec.Run() can be called. It also changes the local file system:
  a temp directory is created and new SSH keys are generated.

  Exits the process with status 1 if the run_uri cannot be initialized.

  Raises:
    errors.Setup.MissingExecutableError: if a required executable is not on
      the PATH.
    errors.Setup.InvalidFlagConfigurationError: if --run_stage_iterations > 1
      is combined with --run_stage_time > 0.
  """
  try:
    _InitializeRunUri()
  except errors.Error as e:
    logging.error(e)
    sys.exit(1)

  # Logging: the temp dir must exist before the log file path is derived.
  vm_util.GenTempDir()
  if FLAGS.use_pkb_logging:
    levels = log_util.LOG_LEVELS
    log_util.ConfigureLogging(
        stderr_log_level=levels[FLAGS.log_level],
        log_path=vm_util.PrependTempDir(log_util.LOG_FILE_NAME),
        run_uri=FLAGS.run_uri,
        file_log_level=levels[FLAGS.file_log_level],
    )
  logging.info('PerfKitBenchmarker version: %s', version.VERSION)
  _LogCommandLineFlags()

  RegisterSkipPendingRunsCheck(_SkipPendingRunsFile)

  # Environment checks.
  if not FLAGS.ignore_package_requirements:
    requirements.CheckBasicRequirements()
  for executable in REQUIRED_EXECUTABLES:
    if vm_util.ExecutableOnPath(executable):
      continue
    raise errors.Setup.MissingExecutableError(
        'Could not find required executable "%s"' % executable
    )

  # These two flags are mutually exclusive.
  if FLAGS.run_stage_iterations > 1 and FLAGS.run_stage_time > 0:
    raise errors.Setup.InvalidFlagConfigurationError(
        'Flags run_stage_iterations and run_stage_time are mutually exclusive'
    )

  vm_util.SSHKeyGen()

  static_vm_path = FLAGS.static_vm_file
  if static_vm_path:
    with open(static_vm_path) as static_vm_fp:
      static_virtual_machine.StaticVirtualMachine.ReadStaticVirtualMachineFile(
          static_vm_fp
      )

  benchmark_lookup.SetBenchmarkModuleFunction(benchmark_sets.BenchmarkModule)
  package_lookup.SetPackageModuleFunction(benchmark_sets.PackageModule)

  # Use at least as many threads as VMs, so that e.g. cluster_boot can launch
  # all of its VMs in parallel.
  if FLAGS.max_concurrent_threads:
    return
  FLAGS.max_concurrent_threads = max(
      background_tasks.MAX_CONCURRENT_THREADS, FLAGS.num_vms
  )
  logging.info(
      'Setting --max_concurrent_threads=%d.', FLAGS.max_concurrent_threads
  )


def RunBenchmarkTasksInSeries(tasks):
  """Runs each task sequentially, in order, in the current process.

  Arguments:
    tasks: list of (func, args, kwargs) tuples, e.g.
      [(RunBenchmarkTask, (spec,), {}),]

  Returns:
    A list holding each func(*args, **kwargs) result, in task order.
  """
  results = []
  for func, args, kwargs in tasks:
    results.append(func(*args, **kwargs))
  return results


def RunBenchmarks():
  """Runs all benchmarks in PerfKitBenchmarker.

  Returns:
    Exit status for the process: 0 if every benchmark succeeded (or there was
    nothing to run, or this was a dry run), 1 otherwise.
  """
  benchmark_specs = _CreateBenchmarkSpecs()
  if FLAGS.randomize_run_order:
    random.shuffle(benchmark_specs)
  if FLAGS.dry_run:
    print('PKB will run with the following configurations:')
    for spec in benchmark_specs:
      print(spec)
      print('')
    return 0

  benchmark_spec_lists = None
  collector = publisher.SampleCollector()
  try:
    tasks = [(RunBenchmarkTask, (spec,), {}) for spec in benchmark_specs]
    if FLAGS.run_processes is None:
      spec_sample_tuples = RunBenchmarkTasksInSeries(tasks)
    else:
      spec_sample_tuples = background_tasks.RunParallelProcesses(
          tasks, FLAGS.run_processes, FLAGS.run_processes_delay
      )
    # Each task yields (specs_for_every_attempt, samples_from_last_attempt).
    # zip(*[]) produces nothing to unpack, so only transpose when at least one
    # benchmark ran.
    if spec_sample_tuples:
      benchmark_spec_lists, sample_lists = zip(*spec_sample_tuples)
      for sample_list in sample_lists:
        collector.samples.extend(sample_list)

  finally:
    if collector.samples:
      collector.PublishSamples()
    # Use the last run in the series of runs.
    if benchmark_spec_lists:
      benchmark_specs = [spec_list[-1] for spec_list in benchmark_spec_lists]
    if benchmark_specs:
      logging.info(benchmark_status.CreateSummary(benchmark_specs))

    logging.info('Complete logs can be found at: %s', log_util.log_local_path)
    logging.info(
        'Completion statuses can be found at: %s',
        vm_util.PrependTempDir(COMPLETION_STATUS_FILE_NAME),
    )

  if stages.TEARDOWN not in FLAGS.run_stage:
    logging.info(
        'To run again with this setup, please use --run_uri=%s', FLAGS.run_uri
    )

  if FLAGS.archive_bucket:
    archive.ArchiveRun(
        vm_util.GetTempDir(),
        FLAGS.archive_bucket,
        gsutil_path=FLAGS.gsutil_path,
        prefix=FLAGS.run_uri + '_',
    )

  # Write completion status file(s)
  if FLAGS.completion_status_file:
    with open(FLAGS.completion_status_file, 'w') as status_file:
      _WriteCompletionStatusFile(benchmark_specs, status_file)
  completion_status_file_name = vm_util.PrependTempDir(
      COMPLETION_STATUS_FILE_NAME
  )
  with open(completion_status_file_name, 'w') as status_file:
    _WriteCompletionStatusFile(benchmark_specs, status_file)

  # Upload PKB logs to GCS after all benchmark runs are complete.
  log_util.CollectPKBLogs(run_uri=FLAGS.run_uri)
  all_benchmarks_succeeded = all(
      spec.status == benchmark_status.SUCCEEDED for spec in benchmark_specs
  )
  return_code = 0 if all_benchmarks_succeeded else 1
  logging.info('PKB exiting with return_code %s', return_code)
  return return_code


def _GenerateBenchmarkDocumentation():
  """Builds the per-benchmark summary lines shown in --help.

  Each line reads "<name>: <description> (<N or variable> VMs[ with scratch
  volume(s)])"; Windows benchmarks get a " (Windows)" suffix on the name.
  """
  docs = []
  all_modules = linux_benchmarks.BENCHMARKS + windows_benchmarks.BENCHMARKS
  for module in all_modules:
    config = configs.LoadMinimalConfig(
        module.BENCHMARK_CONFIG, module.BENCHMARK_NAME
    )
    vm_total = 0
    has_variable_count = False
    scratch_suffix = ''
    for group in config.get('vm_groups', {}).values():
      count = group.get('vm_count', 1)
      # A vm_count of None means the group size is decided at run time.
      if count is None:
        has_variable_count = True
      else:
        vm_total += count
      if group.get('disk_spec'):
        scratch_suffix = ' with scratch volume(s)'
    vm_description = 'variable' if has_variable_count else vm_total

    display_name = module.BENCHMARK_NAME
    if module in windows_benchmarks.BENCHMARKS:
      display_name += ' (Windows)'
    docs.append(
        '%s: %s (%s VMs%s)'
        % (
            display_name,
            config['description'],
            vm_description,
            scratch_suffix,
        )
    )
  return '\n\t'.join(docs)


def _CreateCpuVulnerabilitySamples(vms) -> List[sample.Sample]:
  """Returns samples of the VMs' CPU vulnerabilities.

  One 'cpu_vuln' sample (value 0, no unit) is produced per Linux VM, with the
  VM name and its cpu_vulnerabilities fields in the metadata. Non-Linux VMs
  are skipped.
  """

  # Always returns a Sample; the previous `| None` annotation was inaccurate.
  def CreateSample(vm) -> sample.Sample:
    metadata = {'vm_name': vm.name}
    metadata.update(vm.cpu_vulnerabilities.asdict)
    return sample.Sample('cpu_vuln', 0, '', metadata)

  linux_vms = [vm for vm in vms if vm.OS_TYPE in os_types.LINUX_OS_TYPES]
  return background_tasks.RunThreaded(CreateSample, linux_vms)


def _CreateGccSamples(vms):
  """Returns one 'gcc_version' sample per Linux VM.

  Each sample's metadata holds the VM name plus gcc's version dump and version
  info as reported by build_tools; non-Linux VMs are ignored.
  """

  def _GetGccMetadata(vm):
    return {
        'name': vm.name,
        'versiondump': build_tools.GetVersion(vm, 'gcc'),
        'versioninfo': build_tools.GetVersionInfo(vm, 'gcc'),
    }

  linux_vms = [vm for vm in vms if vm.OS_TYPE in os_types.LINUX_OS_TYPES]
  per_vm_metadata = background_tasks.RunThreaded(_GetGccMetadata, linux_vms)
  return [
      sample.Sample('gcc_version', 0, '', vm_metadata)
      for vm_metadata in per_vm_metadata
  ]


def _CreateGlibcSamples(vms):
  """Returns one 'glibc_version' sample per Linux VM, based on `ldd --version`.

  The metadata's 'versioninfo' is the first line of the command's output, or
  None when the command printed nothing (failures are ignored).
  """

  def _GetGlibcVersionInfo(vm):
    stdout, _ = vm.RemoteCommand('ldd --version', ignore_failure=True)
    if not stdout:
      return None
    # Keep just the first line of the output.
    return stdout.splitlines()[0]

  def _GetGlibcMetadata(vm):
    return {
        'name': vm.name,
        # TODO(user): Add glibc versiondump.
        'versioninfo': _GetGlibcVersionInfo(vm),
    }

  linux_vms = [vm for vm in vms if vm.OS_TYPE in os_types.LINUX_OS_TYPES]
  per_vm_metadata = background_tasks.RunThreaded(_GetGlibcMetadata, linux_vms)
  return [
      sample.Sample('glibc_version', 0, '', vm_metadata)
      for vm_metadata in per_vm_metadata
  ]


def _ParseMeminfo(meminfo_txt: str) -> Tuple[Dict[str, int], List[str]]:
  """Returns the parsed /proc/meminfo data.

  Response has entries such as {'MemTotal' : 32887056, 'Inactive': 4576524}. If
  the /proc/meminfo entry has two values such as
    MemTotal: 32887056 kB
  checks that the last value is 'kB' If it is not then adds that line to the
  2nd value in the tuple.

  Args:
    meminfo_txt: contents of /proc/meminfo

  Returns:
    Tuple where the first entry is a dict of the parsed keys and the second
    are unparsed lines.
  """
  data: Dict[str, int] = {}
  malformed: List[str] = []
  for line in meminfo_txt.splitlines():
    try:
      key, full_value = re.split(r':\s+', line)
      parts = full_value.split()
      if len(parts) == 1 or (len(parts) == 2 and parts[1] == 'kB'):
        data[key] = int(parts[0])
      else:
        malformed.append(line)
    except ValueError:
      # If the line does not match "key: value" or if the value is not an int
      malformed.append(line)
  return data, malformed


@events.benchmark_samples_created.connect
def _CollectMeminfoHandler(
    unused_sender,
    benchmark_spec: bm_spec.BenchmarkSpec,
    samples: List[sample.Sample],
) -> None:
  """Appends a /proc/meminfo sample per Linux VM when --collect_meminfo is set.

  Invoked by events.send with keyword arguments, so the parameter names must
  not change.

  Args:
    benchmark_spec: The benchmark spec.
    samples: Generated samples; new 'meminfo' samples are appended in place.
  """
  if not pkb_flags.COLLECT_MEMINFO.value:
    return

  def CollectMeminfo(vm):
    raw_meminfo, _ = vm.RemoteCommand('cat /proc/meminfo')
    fields, unparsed_lines = _ParseMeminfo(raw_meminfo)
    # Capture the parsed key names before the descriptive keys are added.
    parsed_keys = ','.join(sorted(fields))
    fields['meminfo_keys'] = parsed_keys
    fields['meminfo_vmname'] = vm.name
    fields['meminfo_machine_type'] = vm.machine_type
    fields['meminfo_os_type'] = vm.OS_TYPE
    if unparsed_lines:
      fields['meminfo_malformed'] = ','.join(sorted(unparsed_lines))
    return sample.Sample('meminfo', 0, '', fields)

  linux_vms = [
      vm for vm in benchmark_spec.vms if vm.OS_TYPE in os_types.LINUX_OS_TYPES
  ]
  samples.extend(background_tasks.RunThreaded(CollectMeminfo, linux_vms))


def CaptureVMLogs(
    vms: List[virtual_machine.BaseVirtualMachine],
) -> None:
  """Generates and captures VM logs when --capture_vm_logs is enabled.

  Each log file a VM produces is handed to log_util.CollectVMLogs under the
  current run_uri.
  """
  if not pkb_flags.CAPTURE_VM_LOGS.value:
    return
  for vm in vms:
    captured_paths = vm.GenerateAndCaptureLogs()
    logging.info(
        'Captured the following logs for VM %s: %s', vm.name, captured_paths
    )
    for path in captured_paths:
      log_util.CollectVMLogs(FLAGS.run_uri, path)


def ParseArgs():
  """Parses command-line arguments and, unless only help was requested, sets up PKB.

  Returns:
    0 when --helpmatch or --helpmatchmd was handled (help has been printed and
    SetUpPKB() was NOT called); None after flags are parsed and SetUpPKB() has
    completed.
  """
  _ParseFlags(flag_alias.AliasFlagsFromArgs(sys.argv))

  if FLAGS.helpmatch:
    _PrintHelp(FLAGS.helpmatch)
    return 0
  if FLAGS.helpmatchmd:
    _PrintHelpMD(FLAGS.helpmatchmd)
    return 0

  if not FLAGS.accept_licenses:
    logging.warning(
        'Please run with the --accept_licenses flag to '
        'acknowledge PKB may install software on your behalf.'
    )

  CheckVersionFlag()
  SetUpPKB()
  return None


def Main():
  """Entrypoint for PerfKitBenchmarker.

  Returns:
    The process exit status: 0 when only help output was requested, otherwise
    the status returned by RunBenchmarks().
  """
  assert sys.version_info >= (3, 11), 'PerfKitBenchmarker requires Python 3.11+'
  log_util.ConfigureBasicLogging()
  _InjectBenchmarkInfoIntoDocumentation()
  # ParseArgs returns an exit status (0) after printing help; SetUpPKB() has
  # not run in that case, so benchmarks must not be started.
  early_exit_status = ParseArgs()
  if early_exit_status is not None:
    return early_exit_status
  return RunBenchmarks()
