def apply()

in tfx/dsl/input_resolution/ops/latest_policy_model_op.py [0:0]


  def apply(self, input_dict: typing_utils.ArtifactMultiMap):
    """Finds the latest created model via a certain policy.

    The input_dict is expected to have the following format:

    {
        "model": [Model 1, Model 2, ...],
        "model_blessing": [ModelBlessing 1, ModelBlessing 2, ...],
        "model_infra_blessing": [ModelInfraBlessing 1, ...]
    }

    "model" is a required key. "model_blessing" and "model_infra_blessing" are
    optional keys. If "model_blessing" and/or "model_infra_blessing" are
    provided, then only their lineage w.r.t. the Model artifacts will be
    considered.

    Example use cases for specifying "model_blessing"/"model_infra_blessing"
    include: 1) Resolving inputs to a Pusher 2) Specifying ModelBlessing
    artifacts from a specific Evaluator, in cases where the pipeline has
    multiple Evaluators.

    Note that only the standard TFleX Model, ModelBlessing, ModelInfraBlessing,
    and ModelPush artifacts are supported. The sequences held in input_dict are
    not modified by this method.

    Args:
      input_dict: An input dict containing "model", "model_blessing",
        "model_infra_blessing" as keys and lists of Model, ModelBlessing, and
        ModelInfraBlessing artifacts as values, respectively.

    Returns:
      A dictionary containing the latest Model artifact, as well as the
      ModelBlessing, ModelInfraBlessing, and/or ModelPush based on the Policy.

      For example, for a LATEST_BLESSED policy, the following dict will be
      returned:
      {
        "model": [Model],
        "model_blessing": [ModelBlessing],
        "model_infra_blessing": [ModelInfraBlessing]
      }

      For a LATEST_PUSHED policy, the following dict will be returned:
      {
        "model": [Model],
        "model_push": [ModelPush]
      }

    Raises:
      InvalidArgument: If the models are not Model artifacts. For any policy
        other than LATEST_EXPORTED, also raised if the models mix artifacts
        from the current pipeline with external ones, or if all models are
        external but come from more than one pipeline.
      SkipSignal: If raise_skip_signal is True and one of the following:
        1. The input_dict is empty.
        2. If no models are passed in.
        3. If input_dict contains "model_blessing" and/or "model_infra_blessing"
           as keys but have empty lists as values for both of them.
        4. No latest model was found that matches the policy.
        In these cases the result of `_raise_skip_signal_or_return_empty_dict`
        is returned when no signal is raised.
    """
    if not input_dict:
      return self._raise_skip_signal_or_return_empty_dict(
          'The input dictionary is empty.'
      )

    _validate_input_dict(input_dict)

    if not input_dict[ops_utils.MODEL_KEY]:
      return self._raise_skip_signal_or_return_empty_dict(
          'The "model" key in the input dict contained no Model artifacts.'
      )

    # Order the models from latest created to oldest, breaking ties on the
    # artifact id. `sorted` builds a new list, so the caller's input_dict is
    # left untouched and any Sequence type (e.g. a tuple) is accepted, unlike
    # an in-place `list.sort()`.
    models = sorted(
        input_dict[ops_utils.MODEL_KEY],
        key=lambda a: (a.mlmd_artifact.create_time_since_epoch, a.id),
        reverse=True,
    )

    # Return the latest trained model if the policy is LATEST_EXPORTED.
    if self.policy == Policy.LATEST_EXPORTED:
      return {ops_utils.MODEL_KEY: [models[0]]}

    # The remaining policies need lineage lookups, which are only supported
    # when every model is local, or every model is external and all of them
    # come from a single pipeline.
    are_models_external = [m.is_external for m in models]
    if any(are_models_external) and not all(are_models_external):
      raise exceptions.InvalidArgument(
          'Inputs to the LastestPolicyModel are from both current pipeline and'
          ' external pipeline. LastestPolicyModel does not support such usage.'
      )
    if all(are_models_external):
      pipeline_assets = {
          external_artifact_utils.get_pipeline_asset_from_external_id(
              m.mlmd_artifact.external_id
          )
          for m in models
      }
      if len(pipeline_assets) != 1:
        raise exceptions.InvalidArgument(
            'Input models to the LastestPolicyModel are from multiple'
            ' pipelines. LastestPolicyModel does not support such usage.'
        )

    # If ModelBlessing and/or ModelInfraBlessing artifacts were included in
    # input_dict, then we will only consider those child artifacts.
    specifies_child_artifacts = (
        ops_utils.MODEL_BLESSSING_KEY in input_dict
        or ops_utils.MODEL_INFRA_BLESSING_KEY in input_dict
    )
    # Convert both values to lists before concatenating so that differing
    # Sequence types (e.g. a tuple and the [] default) can be combined.
    input_child_artifacts = list(
        input_dict.get(ops_utils.MODEL_BLESSSING_KEY, [])
    ) + list(input_dict.get(ops_utils.MODEL_INFRA_BLESSING_KEY, []))

    # External child artifacts are matched by the id encoded in their
    # external_id; local ones by their own id.
    input_child_artifact_ids = set()
    for a in input_child_artifacts:
      if a.is_external:
        input_child_artifact_ids.add(
            external_artifact_utils.get_id_from_external_id(
                a.mlmd_artifact.external_id
            )
        )
      else:
        input_child_artifact_ids.add(a.id)

    # If the ModelBlessing and ModelInfraBlessing lists are empty, then no
    # child artifacts can be considered and we raise a SkipSignal. This can
    # occur when a Model has been trained but not blessed yet, for example.
    if specifies_child_artifacts and not input_child_artifact_ids:
      return self._raise_skip_signal_or_return_empty_dict(
          '"model_blessing" and/or "model_infra_blessing" were specified as '
          'keys in the input dictionary, but contained no '
          'ModelBlessing/ModelInfraBlessing artifacts.'
      )

    # In MLMD, two artifacts are related by:
    #
    #       Event 1           Event 2
    # Model ------> Execution ------> Artifact B
    #
    # Artifact B can be:
    # 1. ModelBlessing output artifact from an Evaluator.
    # 2. ModelInfraBlessing output artifact from an InfraValidator.
    # 3. ModelPush output artifact from a Pusher.
    #
    # We query MLMD for the downstream artifacts of the candidate models whose
    # type is one of the three types above. Note we perform batch queries to
    # reduce the number of round trips to the database.

    # There could be multiple events with the same execution ID but different
    # artifact IDs (e.g. model and baseline_model passed to an Evaluator), so we
    # need to deduplicate the Model artifacts.
    deduped_models, model_artifact_ids = _dedpupe_model_artifacts(models)

    downstream_artifact_type_names_filter_query = q.to_sql_string([
        ops_utils.MODEL_BLESSING_TYPE_NAME,
        ops_utils.MODEL_INFRA_BLESSSING_TYPE_NAME,
        ops_utils.MODEL_PUSH_TYPE_NAME,
    ])
    input_child_artifact_ids_filter_query = q.to_sql_string(
        list(input_child_artifact_ids)
    )

    artifact_states_filter_query = (
        ops_utils.get_valid_artifact_states_filter_query(_VALID_ARTIFACT_STATES)
    )
    filter_query = (
        f'type IN {downstream_artifact_type_names_filter_query} AND '
        f'{artifact_states_filter_query}'
    )

    # Restrict the search to the explicitly provided child artifacts.
    if input_child_artifact_ids and specifies_child_artifacts:
      filter_query = (
          f'id IN {input_child_artifact_ids_filter_query} AND {filter_query}'
      )

    if self.policy == Policy.LATEST_PUSHED:
      event_input_key = ops_utils.MODEL_EXPORT_KEY
    else:
      event_input_key = ops_utils.MODEL_KEY

    # Define an event filter for path filtering. An event is considered valid
    # to be included in the path if:
    # 1. It's an input event and not connected to a Model artifact.
    # 2. It's an input event with event_input_key and connected to a Model
    #    artifact.
    # 3. It's an OUTPUT (not PENDING_OUTPUT) event.
    def event_filter(event):
      if event_lib.is_valid_input_event(event):
        if event.artifact_id in model_artifact_ids:
          return event_lib.is_valid_input_event(event, event_input_key)
        else:
          return True
      else:
        return event_lib.is_valid_output_event(event)

    mlmd_resolver = metadata_resolver.MetadataResolver(
        self.context.store,
        mlmd_connection_manager=self.context.mlmd_connection_manager,
    )
    # Populate the ModelRelations associated with each Model artifact and its
    # children.
    model_relations_by_model_identifier = collections.defaultdict(
        ModelRelations
    )
    artifact_type_by_name: Dict[str, metadata_store_pb2.ArtifactType] = {}

    # Query downstream artifacts for `deduped_models` in batches of
    # `ops_utils.BATCH_SIZE`; presumably the size is chosen to respect the
    # resolver's limit on starting artifacts per call (verify against
    # MetadataResolver).
    for id_index in range(0, len(deduped_models), ops_utils.BATCH_SIZE):
      batch_model_artifacts = deduped_models[
          id_index : id_index + ops_utils.BATCH_SIZE
      ]
      # The hop limit comes from ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS.
      batch_downstream_artifacts_and_types_by_model_identifier = (
          mlmd_resolver.get_downstream_artifacts_by_artifacts(
              batch_model_artifacts,
              max_num_hops=ops_utils.LATEST_POLICY_MODEL_OP_MAX_NUM_HOPS,
              filter_query=filter_query,
              event_filter=event_filter,
          )
      )

      for (
          model_identifier,
          artifacts_and_types,
      ) in batch_downstream_artifacts_and_types_by_model_identifier.items():
        for downstream_artifact, artifact_type in artifacts_and_types:
          artifact_type_by_name[artifact_type.name] = artifact_type
          model_relations_by_model_identifier[
              model_identifier
          ].add_downstream_artifact(downstream_artifact)

    # Find the latest model and ModelRelations that meets the Policy. `models`
    # is ordered newest first, so the first match is the latest one.
    result = {}
    for model in models:
      identifier = external_artifact_utils.identifier(model)
      model_relations = model_relations_by_model_identifier[identifier]
      if model_relations.meets_policy(self.policy):
        result[ops_utils.MODEL_KEY] = [model]
        break
    else:
      return self._raise_skip_signal_or_return_empty_dict(
          f'No model found that meets the Policy {Policy(self.policy).name}'
      )

    # `model_relations` still refers to the relations of the matched model.
    return _build_result_dictionary(
        result, model_relations, self.policy, artifact_type_by_name
    )