def run()

in horovod/spark/runner.py [0:0]


def run(fn, args=(), kwargs=None, num_proc=None, start_timeout=None,
        use_mpi=None, use_gloo=None, extra_mpi_args=None,
        env=None, stdout=None, stderr=None, verbose=1, nics=None):
    """
    Runs Horovod on Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.  Defaults to no keyword arguments.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        use_mpi: Controller preference forwarded to `is_gloo_used` and `_launch_job`;
                 presumably requests the MPI controller.  Defaults to None (no explicit preference).
        use_gloo: Controller preference forwarded to `is_gloo_used` and `_launch_job`;
                  presumably requests the Gloo controller.  Defaults to None (no explicit preference).
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        Exception: If there is no active SparkContext, or if none of the registered
                   hosts holds task index 0.  Failures of the Spark job or of the
                   launched Horovod job are propagated after the Spark job group
                   has been cancelled.
    """

    # Use a fresh dict per call instead of a shared mutable default argument.
    if kwargs is None:
        kwargs = {}

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     start_timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            logging.info('Running %d processes (inferred from spark.default.parallelism)...', num_proc)
    else:
        if settings.verbose >= 1:
            logging.info('Running %d processes...', num_proc)
    settings.num_proc = num_proc

    # Single-slot queue handed to the Spark thread; read once after the job finished.
    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, settings.num_proc,
                                               fn, args, kwargs,
                                               settings.key, settings.nics)
    gloo_is_used = is_gloo_used(use_gloo=use_gloo, use_mpi=use_mpi, use_jsrun=False)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings,
                                      use_gloo=gloo_is_used, is_elastic=False)
    try:
        # wait for all tasks to register, notify them and initiate task-to-task address registration
        _notify_and_register_task_addresses(driver, settings)

        # Determine the index grouping based on host hashes.
        # All tasks have registered at this point, so query the mapping once.
        host_hash_indices = driver.task_host_hash_indices()
        host_hashes = sorted(host_hash_indices.keys())

        # Barrel shift until index 0 is in the first host.  Done as one bounded
        # rotation so a mapping without index 0 fails loudly instead of spinning forever.
        first_host_pos = next((pos for pos, host_hash in enumerate(host_hashes)
                               if 0 in host_hash_indices[host_hash]), None)
        if first_host_pos is None:
            raise Exception('Could not find the host of task index 0 among '
                            'the registered hosts: %s' % ', '.join(map(str, host_hashes)))
        host_hashes = host_hashes[first_host_pos:] + host_hashes[:first_host_pos]

        settings.hosts = ','.join('%s:%d' % (host_hash, len(host_hash_indices[host_hash]))
                                  for host_hash in host_hashes)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except BaseException:
        # Deliberately catches everything (including KeyboardInterrupt) so the
        # Spark job is always terminated before the exception propagates.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        try:
            spark_thread.join()
        finally:
            # Shut the driver service down even if waiting for the thread was interrupted.
            driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]