in tfx/dsl/input_resolution/ops/latest_policy_model_op.py [0:0]
def apply(self, input_dict: typing_utils.ArtifactMultiMap):
  """Finds the latest created model via a certain policy.

  The input_dict is expected to have the following format:

  {
      "model": [Model 1, Model 2, ...],
      "model_blessing": [ModelBlessing 1, ModelBlessing 2, ...],
      "model_infra_blessing": [ModelInfraBlessing 1, ...]
  }

  "model" is a required key. "model_blessing" and "model_infra_blessing" are
  optional keys. If "model_blessing" and/or "model_infra_blessing" are
  provided, then only their lineage w.r.t. the Model artifacts will be
  considered.

  Example usecases for specifying "model_blessing"/"model_infra_blessing"
  include: 1) Resolving inputs to a Pusher 2) Specifying ModelBlessing
  artifacts from a specific Evaluator, in cases where the pipeline has
  multiple Evaluators.

  Note that only the standard TFX Model, ModelBlessing, ModelInfraBlessing,
  and ModelPush artifacts are supported.

  The sequences inside input_dict are not modified; the Model artifacts are
  sorted into a new list.

  Args:
    input_dict: An input dict containing "model", "model_blessing",
      "model_infra_blessing" as keys and lists of Model, ModelBlessing, and
      ModelInfraBlessing artifacts as values, respectively.

  Returns:
    A dictionary containing the latest Model artifact, as well as the
    ModelBlessing, ModelInfraBlessing, and/or ModelPush based on the Policy.
    For example, for a LATEST_BLESSED policy, the following dict will be
    returned:
    {
        "model": [Model],
        "model_blessing": [ModelBlessing],
        "model_infra_blessing": [ModelInfraBlessing]
    }
    For a LATEST_PUSHED policy, the following dict will be returned:
    {
        "model": [Model],
        "model_push": [ModelPush]
    }

  Raises:
    InvalidArgument: If the models are not Model artifacts, if the models mix
      artifacts from the current pipeline with artifacts from an external
      pipeline, or if external models come from more than one pipeline.
    SkipSignal: If raise_skip_signal is True and one of the following:
      1. The input_dict is empty.
      2. If no models are passed in.
      3. If input_dict contains "model_blessing" and/or "model_infra_blessing"
         as keys but have empty lists as values for both of them.
      4. No latest model was found that matches the policy.
  """
  if not input_dict:
    return self._raise_skip_signal_or_return_empty_dict(
        'The input dictionary is empty.'
    )

  _validate_input_dict(input_dict)

  if not input_dict[ops_utils.MODEL_KEY]:
    return self._raise_skip_signal_or_return_empty_dict(
        'The "model" key in the input dict contained no Model artifacts.'
    )

  # Sort the models from latest created to oldest (ties broken by id).
  # `sorted` builds a new list, so the caller's input is left untouched, and
  # it also accepts non-list Sequences (e.g. tuples), which have no `.sort()`.
  models = sorted(
      input_dict[ops_utils.MODEL_KEY],
      key=lambda a: (a.mlmd_artifact.create_time_since_epoch, a.id),
      reverse=True,
  )

  # Return the latest trained model if the policy is LATEST_EXPORTED.
  if self.policy == Policy.LATEST_EXPORTED:
    return {ops_utils.MODEL_KEY: [models[0]]}

  # Models must either all come from the current pipeline or all come from
  # one single external pipeline.
  are_models_external = [m.is_external for m in models]
  if any(are_models_external) and not all(are_models_external):
    raise exceptions.InvalidArgument(
        'Inputs to the LastestPolicyModel are from both current pipeline and'
        ' external pipeline. LastestPolicyModel does not support such usage.'
    )
  if all(are_models_external):
    pipeline_assets = {
        external_artifact_utils.get_pipeline_asset_from_external_id(
            m.mlmd_artifact.external_id
        )
        for m in models
    }
    if len(pipeline_assets) != 1:
      raise exceptions.InvalidArgument(
          'Input models to the LastestPolicyModel are from multiple'
          ' pipelines. LastestPolicyModel does not support such usage.'
      )

  # If ModelBlessing and/or ModelInfraBlessing artifacts were included in
  # input_dict, then we will only consider those child artifacts.
  specifies_child_artifacts = (
      ops_utils.MODEL_BLESSSING_KEY in input_dict
      or ops_utils.MODEL_INFRA_BLESSING_KEY in input_dict
  )
  # Unpack into a fresh list: `+` between a list and a tuple would raise.
  input_child_artifacts = [
      *input_dict.get(ops_utils.MODEL_BLESSSING_KEY, []),
      *input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, []),
  ]
  # External artifacts are matched by the id embedded in their external_id;
  # local artifacts by their own id.
  input_child_artifact_ids = {
      external_artifact_utils.get_id_from_external_id(
          a.mlmd_artifact.external_id
      )
      if a.is_external
      else a.id
      for a in input_child_artifacts
  }

  # If the ModelBlessing and ModelInfraBlessing lists are empty, then no
  # child artifacts can be considered and we raise a SkipSignal. This can
  # occur when a Model has been trained but not blessed yet, for example.
  if specifies_child_artifacts and not input_child_artifact_ids:
    return self._raise_skip_signal_or_return_empty_dict(
        '"model_blessing" and/or "model_infra_blessing" were specified as '
        'keys in the input dictionary, but contained no '
        'ModelBlessing/ModelInfraBlessing artifacts.'
    )

  # In MLMD, two artifacts are related by:
  #
  #        Event 1           Event 2
  # Model ------> Execution ------> Artifact B
  #
  # Artifact B can be:
  # 1. ModelBlessing output artifact from an Evaluator.
  # 2. ModelInfraBlessing output artifact from an InfraValidator.
  # 3. ModelPush output artifact from a Pusher.
  #
  # We query MLMD to get a list of candidate model artifact ids that have
  # a child artifact of type child_artifact_type. Note we perform batch
  # queries to reduce the number round trips to the database.

  # There could be multiple events with the same execution ID but different
  # artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
  # need to deduplicate the Model artifacts.
  deduped_models, model_artifact_ids = _dedpupe_model_artifacts(models)

  downstream_artifact_type_names_filter_query = q.to_sql_string([
      ops_utils.MODEL_BLESSING_TYPE_NAME,
      ops_utils.MODEL_INFRA_BLESSSING_TYPE_NAME,
      ops_utils.MODEL_PUSH_TYPE_NAME,
  ])
  artifact_states_filter_query = (
      ops_utils.get_valid_artifact_states_filter_query(_VALID_ARTIFACT_STATES)
  )
  filter_query = (
      f'type IN {downstream_artifact_type_names_filter_query} AND '
      f'{artifact_states_filter_query}'
  )
  # Non-empty ids can only come from the optional child keys, so this branch
  # is taken exactly when child artifacts were specified (the empty case
  # already returned above). Build the id filter only when it is needed.
  if input_child_artifact_ids:
    input_child_artifact_ids_filter_query = q.to_sql_string(
        list(input_child_artifact_ids)
    )
    filter_query = (
        f'id IN {input_child_artifact_ids_filter_query} AND {filter_query}'
    )

  event_input_key = (
      ops_utils.MODEL_EXPORT_KEY
      if self.policy == Policy.LATEST_PUSHED
      else ops_utils.MODEL_KEY
  )

  # Define event filter for paths filtering, the logic is:
  # An event is considered valid to be included in the path if:
  # 1. It's an input event and not connected to a Model artifact.
  # 2. It's an input event with event_input_key and connected to a Model
  #    artifact.
  # 3. It's an OUTPUT (not PENDING_OUTPUT) event.
  def event_filter(event):
    if not event_lib.is_valid_input_event(event):
      return event_lib.is_valid_output_event(event)
    if event.artifact_id in model_artifact_ids:
      return event_lib.is_valid_input_event(event, event_input_key)
    return True

  mlmd_resolver = metadata_resolver.MetadataResolver(
      self.context.store,
      mlmd_connection_manager=self.context.mlmd_connection_manager,
  )

  # Populate the ModelRelations associated with each Model artifact and its
  # children.
  model_relations_by_model_identifier = collections.defaultdict(
      ModelRelations
  )
  artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {}

  # Split the deduplicated models into batches of ops_utils.BATCH_SIZE while
  # fetching downstream artifacts, because
  # `get_downstream_artifacts_by_artifact_ids()` supports at most 100 ids
  # as starting artifact ids (presumably BATCH_SIZE respects that limit).
  for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE):
    batch_model_artifacts = deduped_models[
        id_index : id_index + ops_utils.BATCH_SIZE
    ]
    # The lineage walk is bounded by
    # ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS.
    batch_downstream_artifacts_and_types_by_model_identifier = (
        mlmd_resolver.get_downstream_artifacts_by_artifacts(
            batch_model_artifacts,
            max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
            filter_query=filter_query,
            event_filter=event_filter,
        )
    )
    for (
        model_identifier,
        artifacts_and_types,
    ) in batch_downstream_artifacts_and_types_by_model_identifier.items():
      for downstream_artifact, artifact_type in artifacts_and_types:
        artifact_type_by_name[artifact_type.name] = artifact_type
        model_relations_by_model_identifier[
            model_identifier
        ].add_downstream_artifact(downstream_artifact)

  # Find the latest model and ModelRelations that meets the Policy. `models`
  # is sorted newest first, so the first match is the latest one.
  result = {}
  for model in models:
    identifier = external_artifact_utils.identifier(model)
    model_relations = model_relations_by_model_identifier[identifier]
    if model_relations.meets_policy(self.policy):
      result[ops_utils.MODEL_KEY] = [model]
      break
  else:
    return self._raise_skip_signal_or_return_empty_dict(
        f'No model found that meets the Policy {Policy(self.policy).name}'
    )

  # `model_relations` is the entry of the model that broke out of the loop.
  return _build_result_dictionary(
      result, model_relations, self.policy, artifact_type_by_name
  )