# Copyright 2022 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compiler submodule specialized for NodeInputs."""
from collections.abc import Iterable, Sequence
import functools
from typing import Optional, Type, cast
from tfx import types
from tfx.dsl.compiler import compiler_context
from tfx.dsl.compiler import compiler_utils
from tfx.dsl.compiler import constants
from tfx.dsl.compiler import node_contexts_compiler
from tfx.dsl.components.base import base_component
from tfx.dsl.components.base import base_node
from tfx.dsl.experimental.conditionals import conditional
from tfx.dsl.input_resolution import resolver_op
from tfx.dsl.placeholder import artifact_placeholder
from tfx.dsl.placeholder import placeholder
from tfx.orchestration import data_types_utils
from tfx.orchestration import pipeline
from tfx.proto.orchestration import metadata_pb2
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import channel as channel_types
from tfx.types import channel_utils
from tfx.types import resolved_channel
from tfx.types import value_artifact
from tfx.utils import deprecation_utils
from tfx.utils import name_utils
from tfx.utils import typing_utils
from ml_metadata.proto import metadata_store_pb2

_PropertyPredicate = pipeline_pb2.PropertyPredicate


def _get_tfx_value(value: str) -> pipeline_pb2.Value:
"""Returns a TFX Value containing the provided string."""
return pipeline_pb2.Value(
field_value=data_types_utils.set_metadata_value(
metadata_store_pb2.Value(), value
)
)
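

# For example (illustrative): _get_tfx_value('my-pipeline') returns a
# pipeline_pb2.Value whose field_value has string_value "my-pipeline".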


def _compile_input_graph(
pipeline_ctx: compiler_context.PipelineContext,
tfx_node: base_node.BaseNode,
channel: resolved_channel.ResolvedChannel,
result: pipeline_pb2.NodeInputs,
) -> str:
"""Compiles ResolvedChannel.output_node as InputGraph and returns its ID."""
input_graph_id = pipeline_ctx.get_node_context(tfx_node).get_input_graph_key(
channel.output_node, channel.for_each_context)
if input_graph_id in result.input_graphs:
return input_graph_id
node_to_ids = {}
input_graph = result.input_graphs[input_graph_id]

  def issue_node_id(prefix: str):
    return prefix + str(len(node_to_ids) + 1)

  def get_node_id(node: resolver_op.Node):
if node in node_to_ids:
return node_to_ids[node]
if isinstance(node, resolver_op.InputNode):
return compile_input_node(cast(resolver_op.InputNode, node))
elif isinstance(node, resolver_op.DictNode):
return compile_dict_node(cast(resolver_op.DictNode, node))
elif isinstance(node, resolver_op.OpNode):
return compile_op_node(cast(resolver_op.OpNode, node))
else:
raise NotImplementedError(
'Expected `node` to be one of InputNode, DictNode, or OpNode '
f'but got `{type(node).__name__}` type.')

  def compile_input_node(input_node: resolver_op.InputNode):
node_id = issue_node_id(prefix='input_')
node_to_ids[input_node] = node_id
input_key = (
pipeline_ctx.get_node_context(tfx_node)
.get_input_key(input_node.wrapped))
_compile_input_spec(
pipeline_ctx=pipeline_ctx,
tfx_node=tfx_node,
input_key=input_key,
channel=input_node.wrapped,
hidden=True,
min_count=0,
result=result)
input_graph.nodes[node_id].input_node.input_key = input_key
input_graph.nodes[node_id].output_data_type = input_node.output_data_type
return node_id

  def compile_dict_node(dict_node: resolver_op.DictNode):
node_id = issue_node_id(prefix='dict_')
node_to_ids[dict_node] = node_id
input_graph.nodes[node_id].dict_node.node_ids.update({
key: get_node_id(child_node)
for key, child_node in dict_node.nodes.items()
})
input_graph.nodes[node_id].output_data_type = (
pipeline_pb2.InputGraph.DataType.ARTIFACT_MULTIMAP)
return node_id

  def compile_op_node(op_node: resolver_op.OpNode):
node_id = issue_node_id(prefix='op_')
node_to_ids[op_node] = node_id
op_node_ir = input_graph.nodes[node_id].op_node
if issubclass(op_node.op_type, resolver_op.ResolverOp):
op_node_ir.op_type = op_node.op_type.canonical_name
else:
op_node_ir.op_type = name_utils.get_full_name(
deprecation_utils.get_first_nondeprecated_class(op_node.op_type))
for n in op_node.args:
op_node_ir.args.add().node_id = get_node_id(n)
for k, v in op_node.kwargs.items():
data_types_utils.set_parameter_value(op_node_ir.kwargs[k].value, v)
input_graph.nodes[node_id].output_data_type = op_node.output_data_type
return node_id

  input_graph.result_node = get_node_id(channel.output_node)
return input_graph_id
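

# For intuition, a resolver-function input (illustrative DSL; names
# hypothetical) such as
#
#   examples = my_resolver_fn(upstream.outputs['examples'])
#
# compiles into roughly the following `InputGraph` (text format, simplified):
#
#   nodes {
#     key: "op_1"
#     value { op_node { op_type: "<canonical op name>"
#                       args { node_id: "input_2" } } }
#   }
#   nodes {
#     key: "input_2"
#     value { input_node { input_key: "<implicit input key>" } }
#   }
#   result_node: "op_1"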


def _compile_channel_pb_contexts(
# TODO(b/264728226) Can flatten these args to make it more readable.
types_values_and_predicates: Iterable[
tuple[str, pipeline_pb2.Value, Optional[_PropertyPredicate]]
],
result: pipeline_pb2.InputSpec.Channel,
):
"""Adds contexts to the channel."""
for (
context_type,
context_value,
predicate,
) in types_values_and_predicates:
ctx = result.context_queries.add()
ctx.type.name = context_type
if context_value:
ctx.name.CopyFrom(context_value)
if predicate:
ctx.property_predicate.CopyFrom(predicate)
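

# As an illustration, an entry like
# (constants.PIPELINE_CONTEXT_TYPE_NAME, _get_tfx_value('my-pipeline'), None)
# appends roughly the following context query (text format, simplified):
#
#   context_queries {
#     type { name: "pipeline" }
#     name { field_value { string_value: "my-pipeline" } }
#   }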


def _compile_channel_pb(
artifact_type: Type[types.Artifact],
pipeline_name: str,
node_id: str,
output_key: str,
result: pipeline_pb2.InputSpec.Channel,
):
"""Compile InputSpec.Channel with an artifact type and context filters."""
mlmd_artifact_type = artifact_type._get_artifact_type() # pylint: disable=protected-access
result.artifact_query.type.CopyFrom(mlmd_artifact_type)
result.artifact_query.type.ClearField('properties')
contexts_types_and_values = [(
constants.PIPELINE_CONTEXT_TYPE_NAME,
_get_tfx_value(pipeline_name),
None,
)]
if node_id:
contexts_types_and_values.append(
(
constants.NODE_CONTEXT_TYPE_NAME,
_get_tfx_value(
compiler_utils.node_context_name(pipeline_name, node_id)
),
None,
),
)
_compile_channel_pb_contexts(contexts_types_and_values, result)
if output_key:
result.output_key = output_key
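

# For example, compiling the 'examples' output of node 'my-gen' in pipeline 'p'
# yields roughly the following channel (text format, heavily simplified; the
# context names are actually `pipeline_pb2.Value` messages):
#
#   artifact_query { type { name: "Examples" } }
#   context_queries { type { name: "pipeline" } name: "p" }
#   context_queries { type { name: "node" } name: "p.my-gen" }
#   output_key: "examples"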


def _construct_predicate(
predicate_names_and_values: Sequence[tuple[str, metadata_store_pb2.Value]],
) -> Optional[_PropertyPredicate]:
"""Constructs a PropertyPredicate from a list of name and value pairs."""
if not predicate_names_and_values:
return None
predicates = []
for name, predicate_value in predicate_names_and_values:
predicates.append(
_PropertyPredicate(
value_comparator=_PropertyPredicate.ValueComparator(
property_name=name,
op=_PropertyPredicate.ValueComparator.Op.EQ,
target_value=pipeline_pb2.Value(field_value=predicate_value),
is_custom_property=True,
)
)
)

  def _make_and(lhs, rhs):
    return _PropertyPredicate(
        binary_logical_operator=_PropertyPredicate.BinaryLogicalOperator(
            op=_PropertyPredicate.BinaryLogicalOperator.AND, lhs=lhs, rhs=rhs
        )
    )

  return functools.reduce(_make_and, predicates)
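

# For example, pairs [('a', v1), ('b', v2)] compile into a nested conjunction,
# conceptually AND(EQ(a, v1), EQ(b, v2)): each EQ is a ValueComparator over a
# custom property, and pairs are combined left-to-right with
# BinaryLogicalOperator.AND.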


def _compile_input_spec(
*,
pipeline_ctx: compiler_context.PipelineContext,
tfx_node: base_node.BaseNode,
input_key: str,
channel: channel_types.BaseChannel,
hidden: bool,
min_count: int,
result: pipeline_pb2.NodeInputs,
) -> None:
"""Compiles `BaseChannel` into `InputSpec` at `result.inputs[input_key]`.
Args:
pipeline_ctx: A `PipelineContext`.
tfx_node: A `BaseNode` instance from pipeline DSL.
input_key: An input key that the compiled `InputSpec` would be stored with.
channel: A `BaseChannel` instance to compile.
hidden: If true, this sets `InputSpec.hidden = True`. If the same channel
instances have been called multiple times with different `hidden` value,
then `hidden` will be `False`. In other words, if the channel is ever
compiled with `hidden=False`, it will ignore other `hidden=True`.
min_count: Minimum number of artifacts that should be resolved for this
input key. If min_count is not met during the input resolution, it is
considered as an error.
result: A `NodeInputs` proto to which the compiled result would be written.
"""
if input_key in result.inputs:
    # Already compiled. This can happen when compiling another input channel
    # from the same resolver function output.
if not hidden:
      # Overwrite hidden = False even for an already compiled channel, because
      # an input is only truly hidden if every compilation of the channel
      # requests hidden=True.
result.inputs[input_key].hidden = False
return
if channel in pipeline_ctx.channels:
    # An OutputChannel or PipelineInputChannel from the same pipeline already
    # has compiled IR in pipeline_ctx.channels.
result.inputs[input_key].channels.append(pipeline_ctx.channels[channel])
elif isinstance(channel, channel_types.PipelineOutputChannel):
    # This is the case when PipelineInputs uses pipeline.outputs where the
    # pipeline is external (i.e. not a parent or sibling pipeline), so the
    # pipeline run cannot be synced.
channel = cast(channel_types.PipelineOutputChannel, channel)
_compile_channel_pb(
artifact_type=channel.type,
pipeline_name=channel.pipeline.id,
node_id=channel.wrapped.producer_component_id,
output_key=channel.output_key,
result=result.inputs[input_key].channels.add(),
)
elif isinstance(channel, channel_types.ExternalPipelineChannel):
channel = cast(channel_types.ExternalPipelineChannel, channel)
result_input_channel = result.inputs[input_key].channels.add()
_compile_channel_pb(
artifact_type=channel.type,
pipeline_name=channel.pipeline_name,
node_id=channel.producer_component_id,
output_key=channel.output_key,
result=result_input_channel,
)
if channel.pipeline_run_id or channel.run_context_predicates:
predicate = _construct_predicate(channel.run_context_predicates)
_compile_channel_pb_contexts(
[(
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
_get_tfx_value(
channel.pipeline_run_id if channel.pipeline_run_id else ''
),
predicate,
)],
result_input_channel,
)
if pipeline_ctx.pipeline.platform_config:
project_config = (
pipeline_ctx.pipeline.platform_config.project_platform_config
)
if (
channel.owner != project_config.owner
or channel.pipeline_name != project_config.project_name
):
config = metadata_pb2.MLMDServiceConfig(
owner=channel.owner,
name=channel.pipeline_name,
)
result_input_channel.metadata_connection_config.Pack(config)
else:
config = metadata_pb2.MLMDServiceConfig(
owner=channel.owner,
name=channel.pipeline_name,
)
result_input_channel.metadata_connection_config.Pack(config)
  # Note that this path is *usually* not taken, as most output channels already
  # exist in pipeline_ctx.channels, where they are added after
  # compiler._generate_input_spec_for_outputs is called.
  # This path gets taken when a channel is copied, for example by
  # `as_optional()`, since Channel uses `id` for its hash.
elif isinstance(channel, channel_types.OutputChannel):
channel = cast(channel_types.Channel, channel)
result_input_channel = result.inputs[input_key].channels.add()
_compile_channel_pb(
artifact_type=channel.type,
pipeline_name=pipeline_ctx.pipeline_info.pipeline_name,
node_id=channel.producer_component_id,
output_key=channel.output_key,
result=result_input_channel,
)
node_contexts = node_contexts_compiler.compile_node_contexts(
pipeline_ctx, tfx_node.id
)
contexts_to_add = []
for context_spec in node_contexts.contexts:
if context_spec.type.name == constants.PIPELINE_RUN_CONTEXT_TYPE_NAME:
contexts_to_add.append(
(constants.PIPELINE_RUN_CONTEXT_TYPE_NAME, context_spec.name, None)
)
_compile_channel_pb_contexts(contexts_to_add, result_input_channel)
elif isinstance(channel, channel_types.Channel):
channel = cast(channel_types.Channel, channel)
_compile_channel_pb(
artifact_type=channel.type,
pipeline_name=pipeline_ctx.pipeline_info.pipeline_name,
node_id=channel.producer_component_id,
output_key=channel.output_key,
result=result.inputs[input_key].channels.add(),
)
elif isinstance(channel, channel_types.UnionChannel):
channel = cast(channel_types.UnionChannel, channel)
mixed_inputs = result.inputs[input_key].mixed_inputs
mixed_inputs.method = pipeline_pb2.InputSpec.Mixed.Method.UNION
for sub_channel in channel.channels:
sub_key = (
pipeline_ctx.get_node_context(tfx_node).get_input_key(sub_channel))
mixed_inputs.input_keys.append(sub_key)
_compile_input_spec(
pipeline_ctx=pipeline_ctx,
tfx_node=tfx_node,
input_key=sub_key,
channel=sub_channel,
hidden=True,
min_count=0,
result=result)
elif isinstance(channel, resolved_channel.ResolvedChannel):
channel = cast(resolved_channel.ResolvedChannel, channel)
input_graph_ref = result.inputs[input_key].input_graph_ref
input_graph_ref.graph_id = _compile_input_graph(
pipeline_ctx, tfx_node, channel, result)
if channel.output_key:
input_graph_ref.key = channel.output_key
elif isinstance(channel, channel_utils.ChannelForTesting):
channel = cast(channel_utils.ChannelForTesting, channel)
    # Accessing result.inputs[input_key] creates an empty `InputSpec`. If the
    # testing channel does not point to static artifact IDs, an empty
    # `InputSpec` is enough for testing.
input_spec = result.inputs[input_key]
if channel.artifact_ids:
input_spec.static_inputs.artifact_ids.extend(channel.artifact_ids)
else:
raise NotImplementedError(
f'Node {tfx_node.id} got unsupported channel type {channel!r} for '
f'inputs[{input_key!r}].')
if hidden:
result.inputs[input_key].hidden = True
if min_count:
result.inputs[input_key].min_count = min_count
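

# As a simplified illustration, a plain `Channel(type=Examples)` input compiles
# to roughly the following `InputSpec` under `result.inputs[input_key]`:
#
#   channels {
#     artifact_query { type { name: "Examples" } }
#     context_queries { ... }  # Pipeline and node contexts.
#   }
#   min_count: 1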


def _compile_conditionals(
context: compiler_context.PipelineContext,
tfx_node: base_node.BaseNode,
result: pipeline_pb2.NodeInputs,
) -> None:
"""Compiles conditionals attached to the BaseNode.
It also compiles the channels that each conditional predicate depends on. If
the channel already appears in the node inputs, reuses it. Otherwise, creates
an implicit hidden input.
Args:
context: A `PipelineContext`.
tfx_node: A `BaseNode` instance from pipeline DSL.
result: A `NodeInputs` proto to which the compiled result would be written.
"""
try:
contexts = context.dsl_context_registry.get_contexts(tfx_node)
except ValueError:
return
for dsl_context in contexts:
if not isinstance(dsl_context, conditional.CondContext):
continue
cond_context = cast(conditional.CondContext, dsl_context)
for channel in channel_utils.get_dependent_channels(cond_context.predicate):
      # The channels here always come from a ChannelWrappedPlaceholder (CWP),
      # for which the key is now set by default on OutputChannel, so we must
      # re-create the input key when an output channel is used; otherwise
      # `get_input_key` may return the wrong key (e.g. when the producer
      # component is also used as a data input to this component).
      # Note that this means we potentially have several inputs with identical
      # artifact queries under the hood, which should be optimized away if we
      # run into performance issues.
if isinstance(channel, channel_types.OutputChannel):
input_key = compiler_utils.implicit_channel_key(channel)
else:
input_key = context.get_node_context(tfx_node).get_input_key(channel)
_compile_input_spec(
pipeline_ctx=context,
tfx_node=tfx_node,
input_key=input_key,
channel=channel,
hidden=False,
min_count=1,
result=result,
)
cond_id = context.get_conditional_id(cond_context)
expr = channel_utils.encode_placeholder_with_channels(
cond_context.predicate, context.get_node_context(tfx_node).get_input_key
)
result.conditionals[cond_id].placeholder_expression.CopyFrom(expr)
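

# For instance (illustrative DSL; names hypothetical):
#
#   with conditional.Cond(
#       blessing_channel.future()[0].custom_property('blessed') == 1):
#     trainer = Trainer(...)
#
# compiles the predicate into `result.conditionals[<cond_id>]` and ensures the
# blessing channel also appears in `result.inputs`.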


def _compile_inputs_for_dynamic_properties(
context: compiler_context.PipelineContext,
tfx_node: base_node.BaseNode,
result: pipeline_pb2.NodeInputs,
) -> None:
"""Compiles additional InputSpecs used in dynamic properties.
Dynamic properties are the execution properties whose value comes from the
artifact value. Because of that, dynamic property resolution happens after
the input resolution at orchestrator, so input resolution should include the
resolved artifacts for the channel on which dynamic properties depend (thus
`_compile_channel(hidden=False)`).
Args:
context: A `PipelineContext`.
tfx_node: A `BaseNode` instance from pipeline DSL.
result: A `NodeInputs` proto to which the compiled result would be written.
"""
for key, exec_property in tfx_node.exec_properties.items():
if not isinstance(exec_property, placeholder.Placeholder):
continue
# Validate all the .future().value placeholders. Note that .future().uri is
# also allowed and doesn't need additional validation.
for p in exec_property.traverse():
if isinstance(p, artifact_placeholder._ArtifactValueOperator): # pylint: disable=protected-access
for channel in channel_utils.get_dependent_channels(p):
channel_type = channel.type # is_compatible() needs this variable.
if not typing_utils.is_compatible(
channel_type, Type[value_artifact.ValueArtifact]
):
raise ValueError(
'When you pass <channel>.future().value to an execution '
'property, the channel must be of a value artifact type '
f'(String, Float, ...). Got {channel_type.TYPE_NAME} in exec '
f'property {key!r} of node {tfx_node.id!r}.'
)
for channel in channel_utils.get_dependent_channels(exec_property):
_compile_input_spec(
pipeline_ctx=context,
tfx_node=tfx_node,
input_key=context.get_node_context(tfx_node).get_input_key(channel),
channel=channel,
hidden=False,
min_count=1,
result=result,
)
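

# For instance (illustrative; names hypothetical), a dynamic exec property like
#
#   trainer = MyTrainer(num_steps=tuner.outputs['best_steps'].future().value)
#
# makes the `best_steps` channel an additional (non-hidden) node input, so its
# artifact is resolved before the property value is read at run time.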


def _validate_min_count(
input_key: str,
min_count: int,
channel: channel_types.OutputChannel,
consumer_node: base_node.BaseNode,
) -> None:
"""Validates artifact min count against node execution options.
Note that the validation is not comprehensive. It only applies to components
in the same pipeline. Other min_count violations will be handled as node
failure at run time.
Args:
input_key: Artifact input key to be displayed in error messages.
min_count: Minimum artifact count to be set in InputSpec.
channel: OutputChannel used as an input to be compiled.
consumer_node: Node using the artifact as an input.
Raises:
ValueError: if min_count is invalid.
Returns:
None if the validation passes.
"""
producer_options = channel.producer_component.node_execution_options
if producer_options and producer_options.success_optional and min_count > 0:
raise ValueError(
f'Node({channel.producer_component_id}) is set to success_optional '
f'= True but its consumer Node({consumer_node.id}).inputs[{input_key}] '
'has min_count > 0. The consumer\'s input may need to be optional'
)
consumer_options = consumer_node.node_execution_options
if (
consumer_options
and consumer_options.trigger_strategy
in (
pipeline_pb2.NodeExecutionOptions.ALL_UPSTREAM_NODES_COMPLETED,
pipeline_pb2.NodeExecutionOptions.LAZILY_ALL_UPSTREAM_NODES_COMPLETED,
)
and min_count > 0
):
raise ValueError(
f'Node({consumer_node.id}) has trigger_strategy ='
f' {pipeline_pb2.NodeExecutionOptions.TriggerStrategy.Name(consumer_options.trigger_strategy)} but'
f" its inputs[{input_key}] has min_count > 0. The consumer's input may"
' need to be optional'
)
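

# For example, if an upstream node declares (illustrative)
# `node_execution_options = NodeExecutionOptions(success_optional=True)`, every
# consumer of its outputs must mark the corresponding input as optional
# (min_count == 0); otherwise this validation raises ValueError at compile
# time.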


def compile_node_inputs(
context: compiler_context.PipelineContext,
tfx_node: base_node.BaseNode,
result: pipeline_pb2.NodeInputs,
) -> None:
"""Compile NodeInputs from BaseNode input channels."""
# Compile DSL node inputs.
for input_key, channel in tfx_node.inputs.items():
if compiler_utils.is_resolver(tfx_node):
min_count = 0
elif isinstance(tfx_node, pipeline.Pipeline):
pipeline_inputs_channel = tfx_node.inputs[input_key]
min_count = 0 if pipeline_inputs_channel.is_optional else 1
elif isinstance(tfx_node, base_component.BaseComponent):
spec_param = tfx_node.spec.INPUTS[input_key]
if (
spec_param.allow_empty_explicitly_set
and channel.is_optional is not None
and (spec_param.allow_empty != channel.is_optional)
):
raise ValueError(
f'Node {tfx_node.id} input channel {input_key} allow_empty is set'
f' to {spec_param.allow_empty} but the provided channel is'
f' {channel.is_optional}. If the component spec explicitly sets'
' allow_empty, then the channel must match.'
)
elif spec_param.allow_empty or channel.is_optional:
min_count = 0
else:
min_count = 1
else:
min_count = 1
if isinstance(channel, channel_types.OutputChannel):
_validate_min_count(
input_key=input_key,
min_count=min_count,
channel=channel,
consumer_node=tfx_node,
)
_compile_input_spec(
pipeline_ctx=context,
tfx_node=tfx_node,
input_key=input_key,
channel=channel,
hidden=False,
min_count=min_count,
result=result)
# Add implicit input channels that are used in conditionals.
_compile_conditionals(context, tfx_node, result)
# Add implicit input channels that are used in dynamic properties.
_compile_inputs_for_dynamic_properties(context, tfx_node, result)
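

# Typical usage from the pipeline compiler (illustrative sketch; the caller
# owns the `NodeInputs` proto, usually `pipeline_pb2.PipelineNode.inputs`):
#
#   node_inputs = pipeline_pb2.NodeInputs()
#   compile_node_inputs(pipeline_ctx, tfx_node, node_inputs)
#   pipeline_node.inputs.CopyFrom(node_inputs)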