in tfx/orchestration/experimental/core/async_pipeline_task_gen.py
def _generate_tasks_for_node(
self, metadata_handler: metadata.Metadata,
node: pipeline_pb2.PipelineNode) -> List[task_lib.Task]:
"""Generates a node execution task.
If a node execution is not feasible, `None` is returned.
Args:
metadata_handler: A handler to access MLMD db.
node: The pipeline node for which to generate a task.
Returns:
Returns a `Task` or `None` if task generation is deemed infeasible.
"""
result = []
node_uid = task_lib.NodeUid.from_pipeline_node(self._pipeline, node)
executions = task_gen_utils.get_executions(metadata_handler, node)
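# If an execution of this node is already active in MLMD, resurface it as
# an ExecNodeTask so the orchestrator resumes the in-flight execution
# instead of registering a new one.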
exec_node_task = task_gen_utils.generate_task_from_active_execution(
metadata_handler, self._pipeline, node, executions)
if exec_node_task:
result.append(
task_lib.UpdateNodeStateTask(
node_uid=node_uid, state=pstate.NodeState.RUNNING))
result.append(exec_node_task)
return result
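# No active execution: resolve the node's inputs and exec properties to
# decide whether a new execution should be triggered.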
resolved_info = task_gen_utils.generate_resolved_info(
metadata_handler, node)
# TODO(b/207038460): Update async pipeline to support ForEach.
if (resolved_info is None or not resolved_info.input_artifacts or
resolved_info.input_artifacts[0] is None or
not any(resolved_info.input_artifacts[0].values())):
logging.info(
'Task cannot be generated for node %s since no input artifacts '
'are resolved.', node.node_info.id)
return result
input_artifact = resolved_info.input_artifacts[0]
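# Fingerprint the node's executor spec and record it as an exec property.
# Deterministic proto serialization keeps the digest stable across runs, so
# the fingerprint changes only when the executor spec itself changes; this
# lets an updated executor (e.g. a new container image) trigger a fresh
# execution below even when the input artifacts are unchanged.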
executor_spec_fingerprint = hashlib.sha256()
executor_spec = task_gen_utils.get_executor_spec(
self._pipeline_state.pipeline, node.node_info.id)
if executor_spec is not None:
executor_spec_fingerprint.update(
executor_spec.SerializeToString(deterministic=True))
resolved_info.exec_properties[constants.EXECUTOR_SPEC_FINGERPRINT_KEY] = (
    executor_spec_fingerprint.hexdigest())
# If the latest execution used the same resolved input artifacts, execution
# properties, and executor spec fingerprint, we should not trigger a new
# execution.
latest_exec = task_gen_utils.get_latest_execution(executions)
if latest_exec:
artifact_ids_by_event_type = (
execution_lib.get_artifact_ids_by_event_type_for_execution_id(
metadata_handler, latest_exec.id))
latest_exec_input_artifact_ids = artifact_ids_by_event_type.get(
metadata_store_pb2.Event.INPUT, set())
current_exec_input_artifact_ids = set(
a.id for a in itertools.chain(*input_artifact.values()))
latest_exec_properties = task_gen_utils.extract_properties(latest_exec)
current_exec_properties = resolved_info.exec_properties
latest_exec_executor_spec_fp = latest_exec_properties[
constants.EXECUTOR_SPEC_FINGERPRINT_KEY]
current_exec_executor_spec_fp = current_exec_properties[
    constants.EXECUTOR_SPEC_FINGERPRINT_KEY]
if (latest_exec_input_artifact_ids == current_exec_input_artifact_ids and
_exec_properties_match(latest_exec_properties,
current_exec_properties) and
latest_exec_executor_spec_fp == current_exec_executor_spec_fp):
result.append(
task_lib.UpdateNodeStateTask(
node_uid=node_uid, state=pstate.NodeState.STARTED))
return result
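# Inputs, exec properties or executor spec differ from the latest execution
# (or no previous execution exists), so register a new execution in MLMD.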
execution = execution_publish_utils.register_execution(
metadata_handler=metadata_handler,
execution_type=node.node_info.type,
contexts=resolved_info.contexts,
input_artifacts=input_artifact,
exec_properties=resolved_info.exec_properties)
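# Resolve where this execution's outputs should live, based on the
# pipeline's IR and runtime spec.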
outputs_resolver = outputs_utils.OutputsResolver(
node, self._pipeline.pipeline_info, self._pipeline.runtime_spec,
self._pipeline.execution_mode)
# For mixed service nodes, ensure the associated service jobs are running
# and check their status; the node is aborted if a service job has failed.
service_status = self._ensure_node_services_if_mixed(node.node_info.id)
if service_status is not None:
if service_status != service_jobs.ServiceStatus.RUNNING:
error_msg = f'associated service job failed; node uid: {node_uid}'
result.append(
task_lib.UpdateNodeStateTask(
node_uid=node_uid,
state=pstate.NodeState.FAILED,
status=status_lib.Status(
code=status_lib.Code.ABORTED, message=error_msg)))
return result
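# Service jobs (if any) are healthy: materialize the output artifacts and
# create their directories ahead of executor launch.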
output_artifacts = outputs_resolver.generate_output_artifacts(execution.id)
outputs_utils.make_output_dirs(output_artifacts)
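# Mark the node RUNNING and emit the ExecNodeTask that the task manager
# picks up to launch the executor.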
result.append(
task_lib.UpdateNodeStateTask(
node_uid=node_uid, state=pstate.NodeState.RUNNING))
result.append(
task_lib.ExecNodeTask(
node_uid=node_uid,
execution_id=execution.id,
contexts=resolved_info.contexts,
input_artifacts=input_artifact,
exec_properties=resolved_info.exec_properties,
output_artifacts=output_artifacts,
executor_output_uri=outputs_resolver.get_executor_output_uri(
execution.id),
stateful_working_dir=outputs_resolver
.get_stateful_working_directory(execution.id),
pipeline=self._pipeline))
return result
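

# ---------------------------------------------------------------------------
# `_exec_properties_match` (used above) is defined elsewhere in this module.
# The sketch below is a minimal stand-in illustrating the comparison it needs
# to perform; the name `_exec_properties_match_sketch` and the filtering of
# orchestrator-internal keys (assumed here to carry a double-underscore
# prefix) are illustrative assumptions, not verified details of the real
# helper.
def _exec_properties_match_sketch(props1, props2) -> bool:
  """Returns True if two exec property mappings match on public keys."""
  def _public(props):
    return {k: v for k, v in props.items() if not k.startswith('__')}
  return _public(props1) == _public(props2)


# The executor-spec fingerprint computed above relies on deterministic proto
# serialization: `SerializeToString(deterministic=True)` emits map fields in
# sorted key order, so semantically identical specs hash to the same digest.
# A self-contained illustration, reusing this module's `hashlib` import; the
# helper name `fingerprint_proto` is an assumption for this sketch and works
# for any proto message:
def fingerprint_proto(msg) -> str:
  """Returns a stable SHA-256 hex digest of a proto message."""
  h = hashlib.sha256()
  h.update(msg.SerializeToString(deterministic=True))
  return h.hexdigest()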