def _batch()

in metaflow/plugins/aws/step_functions/step_functions.py [0:0]


    def _batch(self, node):
        """Build the AWS Batch job that runs one step of the flow.

        AWS Step Functions can only pass data between states through the AWS
        Batch job description. So besides resolving the step's resource
        requirements, this method encodes all cross-state plumbing (run id,
        parent task ids, foreach split-parent task ids) in two places:

        - AWS Batch parameters (``attrs``; keys ending in ``.$`` hold JsonPath
          expressions).
        - Environment variables (``env``), many of which are JsonPath
          expressions resolved against the state input.

        Parameters
        ----------
        node
            The flow-graph node of the step to schedule.

        Returns
        -------
        The object returned by ``Batch(...).create_job(...)`` after
        ``.attempts(total_retries + 1)`` has been applied to it.

        Raises
        ------
        StepFunctionsException
            If the step is a parallel foreach step, or if the step needs AWS
            DynamoDB for foreach bookkeeping but ``SFN_DYNAMO_DB_TABLE`` is
            not configured.
        """

        def is_foreach_join(graph_node):
            # True for a join that closes a foreach (as opposed to a static
            # split). `split_parents` is only read once the node is known to
            # be a join.
            return (
                graph_node.type == "join"
                and self.graph[graph_node.split_parents[-1]].type == "foreach"
            )

        # Is this step the join of a foreach?
        closes_foreach = is_foreach_join(node)
        # Is this step inside a foreach and about to transition into a
        # foreach join?
        feeds_foreach_join = node.is_inside_foreach and any(
            is_foreach_join(self.graph[n]) for n in node.out_funcs
        )

        # Fetched once. It is used for `metaflow.version` below and as the
        # base of the `METAFLOW_VERSION` payload.
        environment_info = self.environment.get_environment_info()

        attrs = {
            # metaflow.user is only used for setting the AWS Job Name.
            # Since job executions are no longer tied to a specific user
            # identity, we will just set their user to `SFN`. We still do need
            # access to the owner of the workflow for production tokens, which
            # we can stash in metaflow.owner.
            "metaflow.user": "SFN",
            "metaflow.owner": self.username,
            "metaflow.flow_name": self.flow.name,
            "metaflow.step_name": node.name,
            # Unfortunately we can't set the task id here since AWS Step
            # Functions lacks any notion of run-scoped task identifiers. We
            # instead co-opt the AWS Batch job id as the task id. This also
            # means that the AWS Batch job name will have missing fields since
            # the job id is determined at job execution, but since the job id is
            # part of the job description payload, we don't lose much except for
            # a few ugly looking blank fields in the AWS Batch UI.
            # Also, unfortunately we can't set the retry count since
            # `$$.State.RetryCount` resolves to an int dynamically and
            # AWS Batch job specification only accepts strings. We handle
            # retries/catch within AWS Batch to get around this limitation.
            # And, we also cannot set the run id here since the run id maps to
            # the execution name of the AWS Step Functions State Machine, which
            # is different when executing inside a distributed map. We set it once
            # in the start step and move it along to be consumed by all the children.
            "metaflow.version": environment_info["metaflow_version"],
            # We rely on step names and task ids of parent steps to construct
            # input paths for a task. Since the only information we can pass
            # between states (via `InputPath` and `ResultPath`) in AWS Step
            # Functions is the job description, we run the risk of exceeding
            # 32K state size limit rather quickly if we don't filter the job
            # description to a minimal set of fields. Unfortunately, the partial
            # `JsonPath` implementation within AWS Step Functions makes this
            # work a little non-trivial; it doesn't like dots in keys, so we
            # have to add the field again.
            # This pattern is repeated in a lot of other places, where we use
            # AWS Batch parameters to store AWS Step Functions state
            # information, since this field is the only field in the AWS Batch
            # specification that allows us to set key-values.
            "step_name": node.name,
        }

        # Store production token within the `start` step, so that subsequent
        # `step-functions create` calls can perform a rudimentary authorization
        # check.
        if node.name == "start":
            attrs["metaflow.production_token"] = self.production_token

        # Add env vars from the optional @environment decorator. Copy so the
        # decorator's own dict is not mutated below.
        env_deco = [deco for deco in node.decorators if deco.name == "environment"]
        env = {}
        if env_deco:
            env = env_deco[0].attributes["vars"].copy()

        # add METAFLOW_S3_ENDPOINT_URL
        if S3_ENDPOINT_URL is not None:
            env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL

        if node.name == "start":
            # metaflow.run_id maps to AWS Step Functions State Machine Execution in all
            # cases except for when within a for-each construct that relies on
            # Distributed Map. To work around this issue, we pass the run id from the
            # start step to all subsequent tasks.
            attrs["metaflow.run_id.$"] = "$$.Execution.Name"

            # Initialize parameters for the flow in the `start` step.
            parameters = self._process_parameters()
            if parameters:
                # Get user-defined parameters from State Machine Input.
                # Since AWS Step Functions doesn't allow for optional inputs
                # currently, we have to unfortunately place an artificial
                # constraint that every parameterized workflow needs to include
                # `Parameters` as a key in the input to the workflow.
                # `step-functions trigger` already takes care of this
                # requirement, but within the UI, the users will be required to
                # specify an input with key as `Parameters` and value as a
                # stringified json of the actual parameters -
                # {"Parameters": "{\"alpha\": \"beta\"}"}
                env["METAFLOW_PARAMETERS"] = "$.Parameters"
                default_parameters = {}
                for parameter in parameters:
                    if parameter["value"] is not None:
                        default_parameters[parameter["name"]] = parameter["value"]
                # Dump the default values specified in the flow.
                env["METAFLOW_DEFAULT_PARAMETERS"] = json.dumps(default_parameters)
            # `start` step has no upstream input dependencies aside from
            # parameters.
            input_paths = None
        else:
            # We need to rely on the `InputPath` of the AWS Step Functions
            # specification to grab task ids and the step names of the parent
            # to properly construct input_paths at runtime. Thanks to the
            # JsonPath-foo embedded in the parent states, we have this
            # information easily available.

            if node.parallel_foreach:
                raise StepFunctionsException(
                    "Parallel steps are not supported yet with AWS step functions."
                )

            # Is one of this step's parents a foreach step (i.e. this step is
            # the body of a foreach)?
            follows_foreach = any(
                self.graph[n].type == "foreach" for n in node.in_funcs
            )

            # Handle foreach join.
            if closes_foreach:
                input_paths = (
                    "sfn-${METAFLOW_RUN_ID}/%s/:"
                    "${METAFLOW_PARENT_TASK_IDS}" % node.in_funcs[0]
                )
                # Unfortunately, AWS Batch only allows strings as value types
                # in its specification, and we don't have any way to concatenate
                # the task ids array from the parent steps within AWS Step
                # Functions and pass it down to AWS Batch. We instead have to
                # rely on publishing the state to DynamoDb and fetching it back
                # in within the AWS Batch entry point to set
                # `METAFLOW_PARENT_TASK_IDS`. The state is scoped to the parent
                # foreach task `METAFLOW_SPLIT_PARENT_TASK_ID`. We decided on
                # AWS DynamoDb and not AWS Lambdas, because deploying and
                # debugging Lambdas would be a nightmare as far as OSS support
                # is concerned.
                env["METAFLOW_SPLIT_PARENT_TASK_ID"] = (
                    "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1]
                )
                # Inherit the run id from the parent and pass it along to children.
                attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
            else:
                # Set appropriate environment variables for runtime replacement.
                if len(node.in_funcs) == 1:
                    input_paths = (
                        "sfn-${METAFLOW_RUN_ID}/%s/${METAFLOW_PARENT_TASK_ID}"
                        % node.in_funcs[0]
                    )
                    env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
                    # Inherit the run id from the parent and pass it along to children.
                    attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
                else:
                    # Generate the input paths in a quasi-compressed format.
                    # See util.decompress_list for why this is written the way
                    # it is.
                    input_paths = "sfn-${METAFLOW_RUN_ID}:" + ",".join(
                        "/${METAFLOW_PARENT_%s_STEP}/"
                        "${METAFLOW_PARENT_%s_TASK_ID}" % (idx, idx)
                        for idx, _ in enumerate(node.in_funcs)
                    )
                    # Inherit the run id from the parent and pass it along to children.
                    attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
                    for idx, _ in enumerate(node.in_funcs):
                        env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
                        env["METAFLOW_PARENT_%s_STEP" % idx] = (
                            "$.[%s].Parameters.step_name" % idx
                        )
            env["METAFLOW_INPUT_PATHS"] = input_paths

            if node.is_inside_foreach:
                # Set the task id of the parent job of the foreach split in
                # our favorite dumping ground, the AWS Batch attrs. For
                # subsequent descendant tasks, this attrs blob becomes the
                # input to those descendant tasks. We set and propagate the
                # task ids pointing to split_parents through every state.
                if follows_foreach:
                    attrs["split_parent_task_id_%s.$" % node.split_parents[-1]] = (
                        "$.SplitParentTaskId"
                    )
                    for parent in node.split_parents[:-1]:
                        if self.graph[parent].type == "foreach":
                            attrs["split_parent_task_id_%s.$" % parent] = (
                                "$.Parameters.split_parent_task_id_%s" % parent
                            )
                elif node.type == "join":
                    if closes_foreach:
                        # A foreach join only gets one set of input from the
                        # parent tasks. We filter the Map state to only output
                        # `$.[0]`, since we don't need any of the other outputs,
                        # that information is available to us from AWS DynamoDB.
                        # This has a nice side effect of making our foreach
                        # splits infinitely scalable because otherwise we would
                        # be bounded by the 32K state limit for the outputs. So,
                        # instead of referencing `Parameters` fields by index
                        # (like in `split`), we can just reference them
                        # directly.
                        attrs["split_parent_task_id_%s.$" % node.split_parents[-1]] = (
                            "$.Parameters.split_parent_task_id_%s"
                            % node.split_parents[-1]
                        )
                        for parent in node.split_parents[:-1]:
                            if self.graph[parent].type == "foreach":
                                attrs["split_parent_task_id_%s.$" % parent] = (
                                    "$.Parameters.split_parent_task_id_%s" % parent
                                )
                    else:
                        for parent in node.split_parents:
                            if self.graph[parent].type == "foreach":
                                attrs["split_parent_task_id_%s.$" % parent] = (
                                    "$.[0].Parameters.split_parent_task_id_%s" % parent
                                )
                else:
                    for parent in node.split_parents:
                        if self.graph[parent].type == "foreach":
                            attrs["split_parent_task_id_%s.$" % parent] = (
                                "$.Parameters.split_parent_task_id_%s" % parent
                            )

                # Set `METAFLOW_SPLIT_PARENT_TASK_ID_FOR_FOREACH_JOIN` if the
                # next transition is to a foreach join, so that the
                # stepfunctions decorator can write the mapping for input path
                # to DynamoDb.
                if feeds_foreach_join:
                    env["METAFLOW_SPLIT_PARENT_TASK_ID_FOR_FOREACH_JOIN"] = attrs[
                        "split_parent_task_id_%s.$"
                        % self.graph[node.out_funcs[0]].split_parents[-1]
                    ]

            # Handle split index for for-each.
            if follows_foreach:
                env["METAFLOW_SPLIT_INDEX"] = "$.Index"

        # Set ttl for the values we set in AWS DynamoDB. Every foreach step
        # stores its cardinality there, including a top-level foreach and a
        # foreach out of `start`. So this must not be limited to foreach steps
        # nested inside another foreach.
        if node.type == "foreach" and self.workflow_timeout:
            env["METAFLOW_SFN_WORKFLOW_TIMEOUT"] = self.workflow_timeout

        env["METAFLOW_CODE_URL"] = self.code_package_url
        env["METAFLOW_FLOW_NAME"] = attrs["metaflow.flow_name"]
        env["METAFLOW_STEP_NAME"] = attrs["metaflow.step_name"]
        env["METAFLOW_RUN_ID"] = attrs["metaflow.run_id.$"]
        env["METAFLOW_PRODUCTION_TOKEN"] = self.production_token
        env["SFN_STATE_MACHINE"] = self.name
        env["METAFLOW_OWNER"] = attrs["metaflow.owner"]
        # Can't set `METAFLOW_TASK_ID` due to lack of run-scoped identifiers.
        # We will instead rely on `AWS_BATCH_JOB_ID` as the task identifier.
        # Can't set `METAFLOW_RETRY_COUNT` either due to integer casting issue.
        metadata_env = self.metadata.get_runtime_environment("step-functions")
        env.update(metadata_env)

        # Copy so the extra keys below are added to a private dict rather than
        # to the object returned by `get_environment_info()`.
        version_info = dict(environment_info)
        version_info["flow_name"] = self.graph.name
        version_info["production_token"] = self.production_token
        env["METAFLOW_VERSION"] = json.dumps(version_info)

        # map config values
        cfg_env = {param["name"]: param["kv_name"] for param in self.config_parameters}
        if cfg_env:
            env["METAFLOW_FLOW_CONFIG_VALUE"] = json.dumps(cfg_env)

        # Set AWS DynamoDb Table Name for state tracking for for-eaches.
        # There are three instances when metaflow runtime directly interacts
        # with AWS DynamoDB.
        #   1. To set the cardinality of `foreach`s, which is subsequently
        #      read prior to the instantiation of the Map state by AWS Step
        #      Functions.
        #   2. To set the input paths from the parent steps of a foreach join.
        #   3. To read the input paths in a foreach join.
        if node.type == "foreach" or feeds_foreach_join or closes_foreach:
            if SFN_DYNAMO_DB_TABLE is None:
                raise StepFunctionsException(
                    "An AWS DynamoDB table is needed "
                    "to support foreach in your flow. "
                    "You can create one following the "
                    "instructions listed at *https://a"
                    "dmin-docs.metaflow.org/metaflow-o"
                    "n-aws/deployment-guide/manual-dep"
                    "loyment#scheduling* and "
                    "re-configure Metaflow using "
                    "*metaflow configure aws* on your "
                    "terminal."
                )
            env["METAFLOW_SFN_DYNAMO_DB_TABLE"] = SFN_DYNAMO_DB_TABLE

        # It makes no sense to set env vars to None (shows up as "None" string)
        env = {k: v for k, v in env.items() if v is not None}

        # Resolve AWS Batch resource requirements.
        batch_deco = [deco for deco in node.decorators if deco.name == "batch"][0]
        resources = {}
        resources.update(batch_deco.attributes)
        # Resolve retry strategy.
        user_code_retries, total_retries = self._get_retries(node)

        task_spec = {
            "flow_name": attrs["metaflow.flow_name"],
            "step_name": attrs["metaflow.step_name"],
            "run_id": "sfn-$METAFLOW_RUN_ID",
            # Use AWS Batch job identifier as the globally unique
            # task identifier.
            "task_id": "$AWS_BATCH_JOB_ID",
            # Since retries are handled by AWS Batch, we can rely on
            # AWS_BATCH_JOB_ATTEMPT as the job counter.
            "retry_count": "$((AWS_BATCH_JOB_ATTEMPT-1))",
        }

        return (
            Batch(self.metadata, self.environment)
            .create_job(
                step_name=node.name,
                step_cli=self._step_cli(
                    node, input_paths, self.code_package_url, user_code_retries
                ),
                task_spec=task_spec,
                code_package_sha=self.code_package_sha,
                code_package_url=self.code_package_url,
                code_package_ds=self.flow_datastore.TYPE,
                image=resources["image"],
                queue=resources["queue"],
                iam_role=resources["iam_role"],
                execution_role=resources["execution_role"],
                cpu=resources["cpu"],
                gpu=resources["gpu"],
                memory=resources["memory"],
                run_time_limit=batch_deco.run_time_limit,
                shared_memory=resources["shared_memory"],
                max_swap=resources["max_swap"],
                swappiness=resources["swappiness"],
                efa=resources["efa"],
                use_tmpfs=resources["use_tmpfs"],
                tmpfs_tempdir=resources["tmpfs_tempdir"],
                tmpfs_size=resources["tmpfs_size"],
                tmpfs_path=resources["tmpfs_path"],
                inferentia=resources["inferentia"],
                env=env,
                attrs=attrs,
                host_volumes=resources["host_volumes"],
                efs_volumes=resources["efs_volumes"],
                ephemeral_storage=resources["ephemeral_storage"],
                log_driver=resources["log_driver"],
                log_options=resources["log_options"],
            )
            .attempts(total_retries + 1)
        )