def _generate_tasks_for_node()

in tfx/orchestration/experimental/core/async_pipeline_task_gen.py


  def _generate_tasks_for_node(
      self, metadata_handler: metadata.Metadata,
      node: pipeline_pb2.PipelineNode) -> List[task_lib.Task]:
    """Generates a node execution task.

    If a node execution is not feasible, `None` is returned.

    Args:
      metadata_handler: A handler to access MLMD db.
      node: The pipeline node for which to generate a task.

    Returns:
      Returns a `Task` or `None` if task generation is deemed infeasible.
    """
    result = []
    node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)

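    # If the node already has an active execution in MLMD, resume it: mark the
    # node RUNNING and return the task for the existing execution instead of
    # registering a new one.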
    executions = task_gen_utils.get_executions(metadata_handler, node)
    exec_node_task = task_gen_utils.generate_task_from_active_execution(
        metadata_handler, self._pipeline, node, executions)
    if exec_node_task:
      result.append(
          task_lib.UpdateNodeStateTask(
              node_uid=node_uid, state=pstate.NodeState.RUNNING))
      result.append(exec_node_task)
      return result

    resolved_info = task_gen_utils.generate_resolved_info(
        metadata_handler, node)
    # TODO(b/207038460): Update async pipeline to support ForEach.
    if (resolved_info is None or not resolved_info.input_artifacts or
        resolved_info.input_artifacts[0] is None or
        not any(resolved_info.input_artifacts[0].values())):
      logging.info(
          'Task cannot be generated for node %s since no input artifacts '
          'are resolved.', node.node_info.id)
      return result
    input_artifact = resolved_info.input_artifacts[0]

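    # Fingerprint the executor spec and record it as an execution property, so
    # that a changed executor spec (not just changed inputs or properties)
    # triggers a new execution in the comparison below.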
    executor_spec_fingerprint = hashlib.sha256()
    executor_spec = task_gen_utils.get_executor_spec(
        self._pipeline_state.pipeline, node.node_info.id)
    if executor_spec is not None:
      executor_spec_fingerprint.update(
          executor_spec.SerializeToString(deterministic=True))
    resolved_info.exec_properties[
        constants
        .EXECUTOR_SPEC_FINGERPRINT_KEY] = executor_spec_fingerprint.hexdigest()

    # If the latest execution had the same resolved input artifacts, execution
    # properties and executor specs, we should not trigger a new execution.
    latest_exec = task_gen_utils.get_latest_execution(executions)
    if latest_exec:
      artifact_ids_by_event_type = (
          execution_lib.get_artifact_ids_by_event_type_for_execution_id(
              metadata_handler, latest_exec.id))
      latest_exec_input_artifact_ids = artifact_ids_by_event_type.get(
          metadata_store_pb2.Event.INPUT, set())
      current_exec_input_artifact_ids = set(
          a.id for a in itertools.chain(*input_artifact.values()))
      latest_exec_properties = task_gen_utils.extract_properties(latest_exec)
      current_exec_properties = resolved_info.exec_properties
      latest_exec_executor_spec_fp = latest_exec_properties[
          constants.EXECUTOR_SPEC_FINGERPRINT_KEY]
      current_exec_executor_spec_fp = resolved_info.exec_properties[
          constants.EXECUTOR_SPEC_FINGERPRINT_KEY]
      if (latest_exec_input_artifact_ids == current_exec_input_artifact_ids and
          _exec_properties_match(latest_exec_properties,
                                 current_exec_properties) and
          latest_exec_executor_spec_fp == current_exec_executor_spec_fp):
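        # Nothing has changed since the latest execution; move the node back
        # to STARTED without triggering a redundant execution.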
        result.append(
            task_lib.UpdateNodeStateTask(
                node_uid=node_uid, state=pstate.NodeState.STARTED))
        return result

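    # A new (or changed) set of inputs is available: register a new execution
    # in MLMD with the resolved inputs and execution properties.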
    execution = execution_publish_utils.register_execution(
        metadata_handler=metadata_handler,
        execution_type=node.node_info.type,
        contexts=resolved_info.contexts,
        input_artifacts=input_artifact,
        exec_properties=resolved_info.exec_properties)
    outputs_resolver = outputs_utils.OutputsResolver(
        node, self._pipeline.pipeline_info, self._pipeline.runtime_spec,
        self._pipeline.execution_mode)

    # For mixed service nodes, we ensure node services and check service
    # status; the node is aborted if its service jobs have failed.
    service_status = self._ensure_node_services_if_mixed(node.node_info.id)
    if service_status is not None:
      if service_status != service_jobs.ServiceStatus.RUNNING:
        error_msg = f'associated service job failed; node uid: {node_uid}'
        result.append(
            task_lib.UpdateNodeStateTask(
                node_uid=node_uid,
                state=pstate.NodeState.FAILED,
                status=status_lib.Status(
                    code=status_lib.Code.ABORTED, message=error_msg)))
        return result

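    # Resolve output artifacts for the new execution and create their output
    # directories up front.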
    output_artifacts = outputs_resolver.generate_output_artifacts(execution.id)
    outputs_utils.make_output_dirs(output_artifacts)
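    # Emit a RUNNING state update followed by the ExecNodeTask carrying the
    # resolved inputs, outputs and URIs needed to run the node.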
    result.append(
        task_lib.UpdateNodeStateTask(
            node_uid=node_uid, state=pstate.NodeState.RUNNING))
    result.append(
        task_lib.ExecNodeTask(
            node_uid=node_uid,
            execution_id=execution.id,
            contexts=resolved_info.contexts,
            input_artifacts=input_artifact,
            exec_properties=resolved_info.exec_properties,
            output_artifacts=output_artifacts,
            executor_output_uri=outputs_resolver.get_executor_output_uri(
                execution.id),
            stateful_working_dir=outputs_resolver
            .get_stateful_working_directory(execution.id),
            pipeline=self._pipeline))
    return result