in horovod/spark/runner.py [0:0]
def _task_fn(index: int, driver_addresses, key, settings, use_gloo: bool, is_elastic: bool):
    """Run a single Horovod task inside a Spark executor.

    Starts a SparkTaskService for task `index`, registers it with the Spark
    driver, then blocks until the distributed command has run (the exact wait
    strategy depends on the launch mode — elastic, Gloo, or MPI) and finally
    shuts the task service down.

    :param index: index of this task within the Spark job
    :param driver_addresses: addresses of the Spark driver service
        (presumably as returned by the driver's `addresses()` — verify against caller)
    :param key: secret key for authenticating/encrypting task<->driver RPC;
        passed separately because serialized `settings` do not carry it
    :param settings: Horovod settings object (uses `.key`, `.nics`,
        `.verbose`, `.start_timeout`)
    :param use_gloo: True when Gloo is the controller instead of MPI
    :param is_elastic: True when running Elastic Horovod on Spark
    :return: the result of the command executed by this task
        (`task.fn_result()`)
    :raises Exception: in elastic mode, when the command exits non-zero, so
        that Spark fails this task and restarts it
    """
    # deserialized on Spark workers, settings do not contain the key, so it is given here explicitly
    # Spark RPC communicates the key and supports encryption
    # for convenience, we put it back into settings
    settings.key = key

    # to simplify things, each task is an individual host in Elastic Horovod on Spark
    # further, each attempt (instance) of a task is an individual host in Elastic Horovod on Spark
    # hides availability of shared memory among executors on the same Spark node
    # (salting with index and wall-clock time makes each elastic attempt a unique "host")
    hosthash = host_hash(salt='{}-{}'.format(index, time.time()) if is_elastic else None)

    # provide host hash to mpirun_exec_fn.py via task service
    # gloo_exec_fn.py will get this env var set in request env as well
    os.environ['HOROVOD_HOSTNAME'] = hosthash

    # For elastic or Gloo runs the service enforces a minimum command lifetime
    # so clients have time to collect the exit code; for plain MPI it does not.
    task = task_service.SparkTaskService(index, settings.key, settings.nics,
                                         MINIMUM_COMMAND_LIFETIME_S if is_elastic or use_gloo else None,
                                         settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(), hosthash)

        if not is_elastic:
            # Static mode: wait until all tasks have registered, then find out
            # which task index is local rank zero on this host.
            task.wait_for_initial_registration(settings.start_timeout)
            task_indices_on_this_host = driver_client.task_host_hash_indices(hosthash)
            local_rank_zero_index = task_indices_on_this_host[0]
        else:
            # Elastic mode: every task is its own host, so there is no shared
            # local-rank-zero task to coordinate with.
            local_rank_zero_index = None

        # In elastic all tasks wait for task shutdown signal from driver.
        # With Gloo all tasks wait for the command to start and terminate.
        # With MPI task with first index executes orted which will run mpirun_exec_fn for all tasks.
        if is_elastic:
            # either terminate on task shutdown or command termination
            shutdown_thread = in_thread(driver_client.wait_for_task_shutdown)

            while shutdown_thread.is_alive():
                # Once the command started we wait for its termination
                if task.check_for_command_start(WAIT_FOR_COMMAND_START_DELAY_SECONDS):
                    task.wait_for_command_termination()
                    if task.command_exit_code() != 0:
                        # a non-zero exit makes this Spark task fail so Spark
                        # restarts it as a fresh elastic attempt
                        raise Exception('Command failed, making Spark task fail to restart the task')
                    break

                # While no command started, we can shutdown any time
                shutdown_thread.join(WAIT_FOR_SHUTDOWN_DELAY_SECONDS)
        elif use_gloo or index == local_rank_zero_index:
            # Either Gloo or first task with MPI.
            task.wait_for_command_start(settings.start_timeout)
            task.wait_for_command_termination()
        else:
            # The other tasks with MPI need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(local_rank_zero_index)
            first_task_client = \
                task_service.SparkTaskClient(local_rank_zero_index,
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        return task.fn_result()
    finally:
        # we must not call into shutdown too quickly, task clients run a command
        # and want to wait on the result, we have told task service not to return
        # from wait_for_command_termination too quickly, so we are safe here to shutdown
        # clients have had enough time to connect to the service already
        #
        # the shutdown has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code
        # shutdown does block with network.BasicService._server._block_on_close = True
        task.shutdown()