in horovod/spark/runner.py
def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None,
        use_mpi=None, use_gloo=None, extra_mpi_args=None,
        env=None, stdout=None, stderr=None, verbose=1, nics=None):
"""
Runs Horovod on Spark. Runs `num_proc` processes executing `fn` using the same amount of Spark tasks.
Args:
fn: Function to run.
args: Arguments to pass to `fn`.
kwargs: Keyword arguments to pass to `fn`.
num_proc: Number of Horovod processes. Defaults to `spark.default.parallelism`.
start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
If it is not set as well, defaults to 600 seconds.
extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
env: Environment dictionary to use in Horovod run.
stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
verbose: Debug output verbosity (0-2). Defaults to 1.
nics: List of NICs for tcp network communication.
Returns:
List of results returned by running `fn` on each rank.
"""
    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)
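
    # Timeout helper used while waiting for the Spark tasks to spawn and register;
    # its message is reported if that takes longer than start_timeout seconds.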
    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
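
    # Collect the run parameters (verbosity, extra MPI args, shared secret key,
    # timeout, NICs) into a single Settings object used by the launch helpers below.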
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     start_timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)
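
    # horovod.spark.run must be invoked on the Spark driver, so an active
    # SparkContext is required.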
    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            logging.info('Running %d processes (inferred from spark.default.parallelism)...', num_proc)
    else:
        if settings.verbose >= 1:
            logging.info('Running %d processes...', num_proc)
    settings.num_proc = num_proc
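
    # Single-slot queue through which the background Spark job thread returns
    # the per-task results collected from the executors.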
    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, settings.num_proc,
                                               fn, args, kwargs,
                                               settings.key, settings.nics)
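
    # Decide up front whether the Gloo controller will be used so that the
    # Spark tasks are launched with the matching configuration.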
    gloo_is_used = is_gloo_used(use_gloo=use_gloo, use_mpi=use_mpi, use_jsrun=False)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings,
                                      use_gloo=gloo_is_used, is_elastic=False)
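
    # The Spark tasks are now being launched from a background thread; this
    # thread waits for them to register, then starts Horovod, cancelling the
    # Spark job group if anything goes wrong.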
    try:
        # wait for all tasks to register, notify them and initiate task-to-task address registration
        _notify_and_register_task_addresses(driver, settings)

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]
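
        # Build the hosts string as comma-separated '<host hash>:<slot count>' entries,
        # which is the format the MPI and Gloo launchers expect.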
        settings.hosts = ','.join('%s:%d' % (host_hash, len(driver.task_host_hash_indices()[host_hash]))
                                  for host_hash in host_hashes)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except:
        # Terminate the Spark job on any failure.
        spark_context.cancelJobGroup(spark_job_group)
        # Re-raise the exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure the Spark job did not fail.
    driver.check_for_spark_job_failure()

    # get ranks from driver
    indices_in_rank_order = _get_indices_in_rank_order(driver)

    # If there was no exception, the execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in indices_in_rank_order]
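
For reference, a minimal usage sketch of this API. It assumes an active PySpark session and a Horovod build with TensorFlow support installed on both the driver and the executors; the `train` function, its `magic_number` argument, and the app name are illustrative only:

from pyspark.sql import SparkSession

import horovod.spark


def train(magic_number):
    # Executed once per Horovod rank, each inside its own Spark task.
    import horovod.tensorflow as hvd
    hvd.init()
    print('rank = %d, local_rank = %d, magic_number = %d'
          % (hvd.rank(), hvd.local_rank(), magic_number))
    return hvd.rank()


# horovod.spark.run must be called on the driver with an active SparkContext.
spark = SparkSession.builder.appName('horovod-spark-example').getOrCreate()

# Run 2 Horovod processes as 2 Spark tasks; returns one result per rank.
results = horovod.spark.run(train, args=(42,), num_proc=2, verbose=2)
print(results)  # e.g. [0, 1]

spark.stop()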