metaflow/plugins/parallel_decorator.py (151 lines of code) (raw):
from collections import namedtuple
from metaflow.decorators import StepDecorator
from metaflow.unbounded_foreach import UBF_CONTROL, CONTROL_TASK_TAG
from metaflow.exception import MetaflowException
from metaflow.metadata_provider import MetaDatum
from metaflow.metaflow_current import current, Parallel
import os
import sys
class ParallelDecorator(StepDecorator):
    """
    MF Add To Current
    -----------------
    parallel -> metaflow.metaflow_current.Parallel
        Returns a namedtuple with relevant information about the parallel task.

        @@ Returns
        -------
        Parallel
            `namedtuple` with the following fields:
                - main_ip (`str`)
                    The IP address of the control task.
                - num_nodes (`int`)
                    The total number of tasks created by @parallel
                - node_index (`int`)
                    The index of the current task in all the @parallel tasks.
                - control_task_id (`Optional[str]`)
                    The task ID of the control task. Available to all tasks.

    is_parallel -> bool
        True if the current step is a @parallel step.
    """

    # Decorator name as used on the step (@parallel).
    name = "parallel"
    # This decorator takes no attributes.
    defaults = {}
    # Marker consulted by the runtime to recognize @parallel steps.
    IS_PARALLEL = True

    def __init__(self, attributes=None, statically_defined=False):
        super(ParallelDecorator, self).__init__(attributes, statically_defined)

    def runtime_step_cli(
        self, cli_args, retry_count, max_user_code_retries, ubf_context
    ):
        """Adjust the CLI invocation for the control task.

        For the control task, propagates the parallelism level (taken from
        the task's UBF iterator) via the `num-parallel` command option.
        When running with the local runtime, also pins the control task's
        split index to "0" (the control task doubles as mapper node 0
        locally).
        """
        if ubf_context == UBF_CONTROL:
            num_parallel = cli_args.task.ubf_iter.num_parallel
            cli_args.command_options["num-parallel"] = str(num_parallel)
            if os.environ.get("METAFLOW_RUNTIME_ENVIRONMENT", "local") == "local":
                cli_args.command_options["split_index"] = "0"

    def step_init(
        self, flow, graph, step_name, decorators, environment, flow_datastore, logger
    ):
        """Install the dynamic `current.parallel` property.

        The property resolves its values from MF_PARALLEL_* environment
        variables at access time (not at decoration time); see the comment
        below for why a static `current._update_env` cannot be used.
        """
        self.environment = environment
        # Previously, the `parallel` property was a hardcoded, static property within `current`.
        # Whenever `current.parallel` was called, it returned a named tuple with values coming from
        # environment variables, loaded dynamically at runtime.
        # Now, many of these environment variables are set by compute-related decorators in `task_pre_step`.
        # This necessitates ensuring the correct ordering of the `parallel` and compute decorators if we want to
        # statically set the namedtuple via `current._update_env` in `task_pre_step`. Hence we avoid using
        # `current._update_env` since:
        # - it will set a static named tuple, resolving environment variables only once (at the time of calling `current._update_env`).
        # - we cannot guarantee the order of calling the decorator's `task_pre_step` (calling `current._update_env` may not set
        #   the named tuple with the correct values).
        # Therefore, we explicitly set the property in `step_init` to ensure the property can resolve the appropriate values in the named tuple
        # when accessed at runtime.
        setattr(
            current.__class__,
            "parallel",
            property(
                fget=lambda _: Parallel(
                    main_ip=os.environ.get("MF_PARALLEL_MAIN_IP", "127.0.0.1"),
                    num_nodes=int(os.environ.get("MF_PARALLEL_NUM_NODES", "1")),
                    node_index=int(os.environ.get("MF_PARALLEL_NODE_INDEX", "0")),
                    control_task_id=os.environ.get("MF_PARALLEL_CONTROL_TASK_ID", None),
                )
            ),
        )

    def task_pre_step(
        self,
        step_name,
        task_datastore,
        metadata,
        run_id,
        task_id,
        flow,
        graph,
        retry_count,
        max_user_code_retries,
        ubf_context,
        inputs,
    ):
        """Record parallel-task state and metadata before the user step runs.

        Sets `current.is_parallel`, remembers the input pathspecs (used later
        by `task_decorate` in the local multinode path), and registers the
        `parallel-world-size` metadata entry. Control tasks additionally get
        an `internal_task_type` metadata entry and mark the flow so the
        control task acts as mapper zero.
        """
        from metaflow import current

        # Set `is_parallel` to `True` in `current` just like we
        # with `is_production` in the project decorator.
        current._update_env(
            {
                "is_parallel": True,
            }
        )
        # Saved for task_decorate: workers spawned locally need the same inputs.
        self.input_paths = [obj.pathspec for obj in inputs]
        task_metadata_list = [
            MetaDatum(
                field="parallel-world-size",
                value=flow._parallel_ubf_iter.num_parallel,
                type="parallel-world-size",
                tags=["attempt_id:{0}".format(0)],
            )
        ]
        if ubf_context == UBF_CONTROL:
            # A Task's tags are now those of its ancestral Run, so we are not able
            # to rely on a task's tags to indicate the presence of a control task
            # so, on top of adding the tags above, we also add a task metadata
            # entry indicating that this is a "control task".
            #
            # Here we will also add a task metadata entry to indicate "control
            # task". Within the metaflow repo, the only dependency of such a
            # "control task" indicator is in the integration test suite (see
            # Step.control_tasks() in client API).
            task_metadata_list += [
                MetaDatum(
                    field="internal_task_type",
                    value=CONTROL_TASK_TAG,
                    type="internal_task_type",
                    tags=["attempt_id:{0}".format(0)],
                )
            ]
            flow._control_task_is_mapper_zero = True
        metadata.register_metadata(run_id, step_name, task_id, task_metadata_list)

    def task_decorate(
        self, step_func, flow, graph, retry_count, max_user_code_retries, ubf_context
    ):
        """Wrap the user step function with distributed-environment setup.

        For a control task running under the local runtime, returns a partial
        of `_local_multinode_control_task_step_func` that spawns the worker
        tasks as subprocesses; otherwise returns the step function wrapped
        with `setup_distributed_env`.
        """

        def _step_func_with_setup():
            self.setup_distributed_env(flow)
            step_func()

        if (
            ubf_context == UBF_CONTROL
            and os.environ.get("METAFLOW_RUNTIME_ENVIRONMENT", "local") == "local"
        ):
            from functools import partial

            # Prefer the base environment (e.g. outside conda) if one exists;
            # see the comment in _local_multinode_control_task_step_func.
            env_to_use = getattr(self.environment, "base_env", self.environment)

            return partial(
                _local_multinode_control_task_step_func,
                flow,
                env_to_use,
                _step_func_with_setup,
                retry_count,
                ",".join(self.input_paths),
            )
        else:
            return _step_func_with_setup

    def setup_distributed_env(self, flow):
        # Overridden by subclasses to set up particular framework's environment.
        pass
def _local_multinode_control_task_step_func(
    flow, env_to_use, step_func, retry_count, input_paths
):
    """
    Used as multinode UBF control task when run in local mode.

    Spawns one subprocess per worker node (node indices 1..num_parallel-1),
    runs the control task itself in-process as node 0, then waits for all
    workers and fails if any of them returned a non-zero exit code.

    Parameters
    ----------
    flow : FlowSpec
        The running flow; must have `_unbounded_foreach` set and a
        `_parallel_ubf_iter` whose class is `ParallelUBF`.
    env_to_use : MetaflowEnvironment
        Environment used to resolve the executable for the worker commands.
    step_func : Callable[[], None]
        Zero-argument callable executing the user step code for the control task.
    retry_count : int
        Current retry attempt, forwarded to the worker tasks.
    input_paths : str
        Comma-joined input pathspecs forwarded to each worker.

    Raises
    ------
    MetaflowException
        If the flow's UBF iterator is not a ParallelUBF.
    Exception
        If any worker subprocess exits with a non-zero return code.
    """
    from metaflow import current
    from metaflow.cli_args import cli_args
    from metaflow.unbounded_foreach import UBF_TASK
    import subprocess

    assert flow._unbounded_foreach
    foreach_iter = flow._parallel_ubf_iter
    if foreach_iter.__class__.__name__ != "ParallelUBF":
        # Fixed: the message previously said "ParallelUBFIter" (and lacked a
        # space after the colon) although the check above is for "ParallelUBF".
        raise MetaflowException(
            "Expected ParallelUBF iterator object, got: "
            + foreach_iter.__class__.__name__
        )

    num_parallel = foreach_iter.num_parallel
    os.environ["MF_PARALLEL_NUM_NODES"] = str(num_parallel)
    os.environ["MF_PARALLEL_MAIN_IP"] = "127.0.0.1"
    os.environ["MF_PARALLEL_CONTROL_TASK_ID"] = str(current.task_id)

    run_id = current.run_id
    step_name = current.step_name
    control_task_id = current.task_id

    # UBF handling for multinode case: the control task doubles as mapper node 0.
    mapper_task_ids = [control_task_id]
    # If we are running inside Conda, we use the base executable FIRST;
    # the conda environment will then be used when runtime_step_cli is
    # called. This is so that it can properly set up all the metaflow
    # aliases needed.
    executable = env_to_use.executable(step_name)
    script = sys.argv[0]

    # start workers
    # TODO: Logs for worker processes are assigned to control process as of today, which
    # should be fixed at some point
    subprocesses = []
    for node_index in range(1, num_parallel):
        task_id = "%s_node_%d" % (control_task_id, node_index)
        mapper_task_ids.append(task_id)
        # The spawned subprocess inherits this environment variable.
        os.environ["MF_PARALLEL_NODE_INDEX"] = str(node_index)
        # Override specific `step` kwargs.
        kwargs = cli_args.step_kwargs
        kwargs["split_index"] = str(node_index)
        kwargs["run_id"] = run_id
        kwargs["task_id"] = task_id
        kwargs["input_paths"] = input_paths
        kwargs["ubf_context"] = UBF_TASK
        kwargs["retry_count"] = str(retry_count)

        cmd = cli_args.step_command(executable, script, step_name, step_kwargs=kwargs)
        p = subprocess.Popen(cmd)
        subprocesses.append(p)

    flow._control_mapper_tasks = [
        "%s/%s/%s" % (run_id, step_name, mapper_task_id)
        for mapper_task_id in mapper_task_ids
    ]

    # run the step function ourselves, in-process, as node 0
    os.environ["MF_PARALLEL_NODE_INDEX"] = "0"
    step_func()

    # join the subprocesses; surface the first worker failure
    for p in subprocesses:
        p.wait()
        if p.returncode:
            raise Exception("Subprocess failed, return code {}".format(p.returncode))