def run_elastic()

in horovod/spark/runner.py [0:0]


def run_elastic(fn, args=(), kwargs=None, num_proc=None, min_np=None, max_np=None,
                start_timeout=None, elastic_timeout=None, reset_limit=None, env=None, verbose=1, nics=None):
    """
    Runs Elastic Horovod on Spark.  Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.  Defaults to an empty dict.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        min_np: Minimum number of processes passed to the elastic settings.  Defaults to `num_proc`.
        max_np: Maximum number of processes passed to the elastic settings and the
                Spark driver service.  Defaults to `num_proc`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        elastic_timeout: Timeout for elastic initialisation after re-scaling the cluster.
                       The value is forwarded unchanged to the elastic settings; any fallback
                       (e.g. the `HOROVOD_ELASTIC_TIMEOUT` environment variable) is not resolved
                       in this function and is expected to be applied downstream.
        reset_limit: Maximum number of resets after which the job is terminated.
        env: Environment dictionary to use in Horovod run.  Forwarded unchanged to
             `gloo_run_elastic`, which is expected to default to `os.environ` when it is None.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.  Converted to a set if it is not one.

    Returns:
        List of results returned by running `fn` on each rank.

    Raises:
        ValueError: If Horovod has not been built with Gloo support.
        Exception: If there is no active SparkContext.  Any exception raised while
                   registering tasks or running the job is re-raised after the Spark
                   job group has been cancelled and the driver has been shut down.
    """

    # Use a fresh dict per call instead of a shared mutable default argument.
    if kwargs is None:
        kwargs = {}

    if not gloo_built(verbose=(verbose >= 2)):
        raise ValueError('Gloo support is required to use elastic training, but has not been built.  Ensure CMake is '
                         'installed and reinstall Horovod with HOROVOD_WITH_GLOO=1 to debug the build error.')

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    if num_proc is None:
        # TODO: #2023 try spark.dynamicAllocation.initialExecutors
        num_proc = spark_context.defaultParallelism
        if verbose >= 1:
            logging.info('Running %d processes (inferred from spark.default.parallelism)...', num_proc)
    else:
        if verbose >= 1:
            logging.info('Running %d processes...', num_proc)

    if min_np is None:
        # TODO: #2023 try spark.dynamicAllocation.minExecutors
        min_np = num_proc
    if max_np is None:
        # TODO: #2023 try spark.dynamicAllocation.maxExecutors
        max_np = num_proc

    # start Spark driver service and launch settings.num_proc Spark tasks
    key = secret.make_secret_key()
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(num_proc, max_np,
                                               fn, args, kwargs,
                                               key, nics)

    # From here on the driver exists, so every exit path (including failures while
    # building the settings or starting the Spark thread) must shut it down.
    try:
        discovery = host_discovery.SparkDriverHostDiscovery(driver)

        tmout = timeout.Timeout(start_timeout,
                                message='Timed out waiting for {activity}. Please check that you have '
                                        'enough resources to run all Horovod processes. Each Horovod '
                                        'process runs in a Spark task. You may need to increase the '
                                        'start_timeout parameter to a larger value if your Spark resources '
                                        'are allocated on-demand.')
        settings = hvd_elastic_settings.ElasticSettings(discovery=discovery,
                                                        min_np=min_np,
                                                        max_np=max_np,
                                                        elastic_timeout=elastic_timeout,
                                                        reset_limit=reset_limit,
                                                        num_proc=num_proc,
                                                        verbose=verbose,
                                                        key=key,
                                                        start_timeout=tmout,
                                                        nics=nics,
                                                        run_func_mode=True)

        result_queue = queue.Queue(1)

        # launch settings.num_proc / settings.max_np Spark tasks
        spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                          result_queue, settings, use_gloo=True, is_elastic=True)
        try:
            # Register task addresses of initial num_proc tasks
            _register_task_addresses(driver, settings)

            # Run the job
            gloo_run_elastic(settings, driver, env)
        except BaseException:
            # Deliberately broad (covers KeyboardInterrupt too): the Spark job must be
            # terminated on any failure before the exception propagates.
            spark_context.cancelJobGroup(spark_job_group)

            # Re-raise exception.
            raise
        finally:
            # Wait for the Spark thread before the driver is shut down.
            spark_thread.join()
    finally:
        driver.shutdown()

    # Make sure Spark Job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]