# update_eval_config_with_defaults()
#
# From tensorflow_model_analysis/utils/config_util.py [0:0]


def _clear_change_threshold(threshold):
  """Clears the change_threshold field on a MetricThreshold if it is set.

  The ByteSize() guard avoids touching (and thereby materializing) empty
  threshold messages.
  """
  if threshold.ByteSize():
    threshold.ClearField('change_threshold')


def update_eval_config_with_defaults(
    eval_config: config_pb2.EvalConfig,
    maybe_add_baseline: Optional[bool] = None,
    maybe_remove_baseline: Optional[bool] = None,
    has_baseline: Optional[bool] = False,
    rubber_stamp: Optional[bool] = False) -> config_pb2.EvalConfig:
  """Returns a new config with default settings applied.

  a) Add or remove a model_spec according to "has_baseline".
  b) Fix the model names (model_spec.name) to tfma.CANDIDATE_KEY and
     tfma.BASELINE_KEY.
  c) Update the metrics_specs with the fixed model name.

  Args:
    eval_config: Original eval config.
    maybe_add_baseline: DEPRECATED. True to add a baseline ModelSpec to the
      config as a copy of the candidate ModelSpec that should already be
      present. This is only applied if a single ModelSpec already exists in the
      config and that spec doesn't have a name associated with it. When applied
      the model specs will use the names tfma.CANDIDATE_KEY and
      tfma.BASELINE_KEY. Only one of maybe_add_baseline or maybe_remove_baseline
      should be used.
    maybe_remove_baseline: DEPRECATED. True to remove a baseline ModelSpec from
      the config if it already exists. Removal of the baseline also removes any
      change thresholds. Only one of maybe_add_baseline or maybe_remove_baseline
      should be used.
    has_baseline: True to add a baseline ModelSpec to the config as a copy of
      the candidate ModelSpec that should already be present. This is only
      applied if a single ModelSpec already exists in the config and that spec
      doesn't have a name associated with it. When applied the model specs will
      use the names tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. False to remove a
      baseline ModelSpec from the config if it already exists. Removal of the
      baseline also removes any change thresholds. Only one of has_baseline or
      maybe_remove_baseline should be used.
    rubber_stamp: True if this model is being rubber stamped. When a model is
      rubber stamped diff thresholds will be ignored if an associated baseline
      model is not passed.

  Returns:
    A copy of eval_config with defaults applied. The input config is not
    mutated.

  Raises:
    RuntimeError: on missing baseline model for non-rubberstamp cases.
    ValueError: if both maybe_add_baseline and maybe_remove_baseline are set,
      or if either deprecated flag is combined with has_baseline.
  """
  # Change thresholds require a baseline to diff against; without one the
  # comparison is meaningless unless this is a rubber-stamp (first) run.
  if (not has_baseline and has_change_threshold(eval_config) and
      not rubber_stamp):
    raise RuntimeError(
        'There are change thresholds, but the baseline is missing. '
        'This is allowed only when rubber stamping (first run).')

  # Work on a copy so the caller's config is left untouched.
  updated_config = config_pb2.EvalConfig()
  updated_config.CopyFrom(eval_config)
  # if user requests CIs but doesn't set method, use JACKKNIFE
  if (eval_config.options.compute_confidence_intervals.value and
      eval_config.options.confidence_intervals.method ==
      config_pb2.ConfidenceIntervalOptions.UNKNOWN_CONFIDENCE_INTERVAL_METHOD):
    updated_config.options.confidence_intervals.method = (
        config_pb2.ConfidenceIntervalOptions.JACKKNIFE)
  if maybe_add_baseline and maybe_remove_baseline:
    raise ValueError('only one of maybe_add_baseline and maybe_remove_baseline '
                     'should be used')
  if maybe_add_baseline or maybe_remove_baseline:
    logging.warning(
        """"maybe_add_baseline" and "maybe_remove_baseline" are deprecated,
        please use "has_baseline" instead.""")
    if has_baseline:
      raise ValueError(
          """"maybe_add_baseline" and "maybe_remove_baseline" are ignored if
          "has_baseline" is set.""")
  # has_baseline (when explicitly provided) takes over from the deprecated
  # flags: True maps to add-baseline, False to remove-baseline.
  if has_baseline is not None:
    if has_baseline:
      maybe_add_baseline = True
    else:
      maybe_remove_baseline = True

  # Has a baseline model. Only applies when there is exactly one unnamed
  # candidate spec to clone; otherwise the config is assumed to already be
  # fully specified.
  if (maybe_add_baseline and len(updated_config.model_specs) == 1 and
      not updated_config.model_specs[0].name):
    baseline = updated_config.model_specs.add()
    baseline.CopyFrom(updated_config.model_specs[0])
    baseline.name = constants.BASELINE_KEY
    baseline.is_baseline = True
    updated_config.model_specs[0].name = constants.CANDIDATE_KEY
    logging.info(
        'Adding default baseline ModelSpec based on the candidate ModelSpec '
        'provided. The candidate model will be called "%s" and the baseline '
        'will be called "%s": updated_config=\n%s', constants.CANDIDATE_KEY,
        constants.BASELINE_KEY, updated_config)

  # Does not have a baseline: drop baseline specs and all change thresholds
  # (they cannot be evaluated without a baseline).
  if maybe_remove_baseline:
    tmp_model_specs = [
        model_spec for model_spec in updated_config.model_specs
        if not model_spec.is_baseline
    ]
    del updated_config.model_specs[:]
    updated_config.model_specs.extend(tmp_model_specs)
    for metrics_spec in updated_config.metrics_specs:
      # Thresholds may live per-metric or at the spec level, each with plain,
      # per-slice, and cross-slice variants.
      for metric in metrics_spec.metrics:
        _clear_change_threshold(metric.threshold)
        for per_slice_threshold in metric.per_slice_thresholds:
          _clear_change_threshold(per_slice_threshold.threshold)
        for cross_slice_threshold in metric.cross_slice_thresholds:
          _clear_change_threshold(cross_slice_threshold.threshold)
      for threshold in metrics_spec.thresholds.values():
        _clear_change_threshold(threshold)
      for per_slice_thresholds in metrics_spec.per_slice_thresholds.values():
        for per_slice_threshold in per_slice_thresholds.thresholds:
          _clear_change_threshold(per_slice_threshold.threshold)
      for cross_slice_thresholds in metrics_spec.cross_slice_thresholds.values(
      ):
        for cross_slice_threshold in cross_slice_thresholds.thresholds:
          _clear_change_threshold(cross_slice_threshold.threshold)
    logging.info(
        'Request was made to ignore the baseline ModelSpec and any change '
        'thresholds. This is likely because a baseline model was not provided: '
        'updated_config=\n%s', updated_config)

  # Guarantee at least one (possibly empty) ModelSpec.
  if not updated_config.model_specs:
    updated_config.model_specs.add()

  model_names = [spec.name for spec in updated_config.model_specs]
  # With a single model the name is irrelevant and is normalized to '' so
  # metrics_specs and model_specs stay consistent.
  if len(model_names) == 1 and model_names[0]:
    logging.info(
        'ModelSpec name "%s" is being ignored and replaced by "" because a '
        'single ModelSpec is being used', model_names[0])
    updated_config.model_specs[0].name = ''
    model_names = ['']
  for spec in updated_config.metrics_specs:
    if not spec.model_names:
      spec.model_names.extend(model_names)
    elif len(model_names) == 1:
      del spec.model_names[:]
      spec.model_names.append('')

  return updated_config