def _model_fn()

in models/official/retinanet/retinanet_model.py


def _model_fn(features, labels, mode, params, model, use_tpu_estimator_spec,
              variable_filter_fn=None):
  """Model defination for the RetinaNet model based on ResNet.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets, which are dense label maps. The labels are generated by
      the get_input_fn function in dataloader.py.
    mode: the TPUEstimator/Estimator mode, one of TRAIN, EVAL, or PREDICT.
    params: the dictionary that defines the hyperparameters of the model. The
      default settings are in the default_hparams function in this file.
    model: the RetinaNet model, a callable that outputs class logits and box
      regression outputs.
    use_tpu_estimator_spec: Whether to use TPUEstimatorSpec or EstimatorSpec.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    tpu_spec: the TPUEstimatorSpec (or EstimatorSpec, when
      use_tpu_estimator_spec is False) to run training, evaluation, or
      prediction.
  """

  # In predict mode, features is a dict and the input image tensor is the
  # value under the 'inputs' key.
  image_info = None
  if (mode == tf.estimator.ModeKeys.PREDICT
      and isinstance(features, dict) and 'inputs' in features):
    image_info = features['image_info']
    labels = None
    if 'labels' in features:
      labels = features['labels']
    features = features['inputs']

  def _model_outputs():
    return model(
        features,
        min_level=params['min_level'],
        max_level=params['max_level'],
        num_classes=params['num_classes'],
        num_anchors=len(params['aspect_ratios']) * params['num_scales'],
        resnet_depth=params['resnet_depth'],
        is_training_bn=params['is_training_bn'])

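  # Run the network in bfloat16 when requested, then cast the per-level
  # outputs back to float32 so the loss and postprocessing use full precision.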
  if params['use_bfloat16']:
    with contrib_tpu.bfloat16_scope():
      cls_outputs, box_outputs = _model_outputs()
      levels = cls_outputs.keys()
      for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
  else:
    cls_outputs, box_outputs = _model_outputs()
    levels = cls_outputs.keys()

  # First check if it is in PREDICT mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    # Postprocess on host; memory layout for NMS on TPU is very inefficient.
    def _predict_postprocess_wrapper(args):
      return _predict_postprocess(*args)

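    # outside_compilation runs the wrapped postprocess (including NMS) on the
    # host CPU instead of the TPU.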
    predictions = contrib_tpu.outside_compilation(
        _predict_postprocess_wrapper,
        (cls_outputs, box_outputs, labels, params))

    # Include resizing information on prediction output to help bbox drawing.
    if image_info is not None:
      predictions.update({
          'image_info': tf.identity(image_info, 'ImageInfo'),
      })

    return contrib_tpu.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions)

  # Load pretrained model from checkpoint.
  if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

    def scaffold_fn():
      """Loads pretrained model through scaffold function."""
      tf.train.init_from_checkpoint(params['resnet_checkpoint'], {
          '/': 'resnet%s/' % params['resnet_depth'],
      })
      return tf.train.Scaffold()
  else:
    scaffold_fn = None

  # Set up training loss and learning rate.
  update_learning_rate_schedule_parameters(params)
  global_step = tf.train.get_global_step()
  learning_rate = learning_rate_schedule(
      params['adjusted_learning_rate'], params['lr_warmup_init'],
      params['lr_warmup_step'], params['first_lr_drop_step'],
      params['second_lr_drop_step'], global_step)
  # cls_loss and box_loss are for logging. Only total_loss is optimized.
  total_loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs,
                                                  labels, params)
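  # Add L2 weight decay over all trainable variables except batch
  # normalization parameters.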
  total_loss += _WEIGHT_DECAY * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, momentum=params['momentum'])
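    # On TPU, CrossShardOptimizer aggregates gradients across all cores before
    # applying the update.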
    if params['use_tpu']:
      optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
    else:
      if params['auto_mixed_precision']:
        optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
            optimizer)

    # Batch norm requires `update_ops` to be executed alongside `train_op`.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    var_list = variable_filter_fn(
        tf.trainable_variables(),
        params['resnet_depth']) if variable_filter_fn else None

    minimize_op = optimizer.minimize(total_loss, global_step, var_list=var_list)
    train_op = tf.group(minimize_op, update_ops)

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(**kwargs):
      """Returns a dictionary that has the evaluation metrics."""
      batch_size = params['batch_size']
      eval_anchors = anchors.Anchors(
          params['min_level'], params['max_level'], params['num_scales'],
          params['aspect_ratios'], params['anchor_scale'], params['image_size'])
      anchor_labeler = anchors.AnchorLabeler(eval_anchors,
                                             params['num_classes'])
      cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
      box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
      coco_metrics = coco_metric_fn(batch_size, anchor_labeler,
                                    params['val_json_file'], **kwargs)

      # Add metrics to output.
      output_metrics = {
          'cls_loss': cls_loss,
          'box_loss': box_loss,
      }
      output_metrics.update(coco_metrics)
      return output_metrics

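    # TPU eval requires every metric_fn input to carry a per-example leading
    # dimension, so tile the scalar losses out to [batch_size, 1].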
    cls_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(cls_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    box_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(box_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    metric_fn_inputs = {
        'cls_loss_repeat': cls_loss_repeat,
        'box_loss_repeat': box_loss_repeat,
        'source_ids': labels['source_ids'],
        'groundtruth_data': labels['groundtruth_data'],
        'image_scales': labels['image_scales'],
    }
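    # add_metric_fn_inputs attaches the per-level class and box outputs needed
    # by the COCO metric computation.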
    add_metric_fn_inputs(params, cls_outputs, box_outputs, metric_fn_inputs)
    eval_metrics = (metric_fn, metric_fn_inputs)

  if use_tpu_estimator_spec:
    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  else:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        # TODO(rostam): Fix bug to get scaffold working.
        # scaffold=scaffold_fn(),
        train_op=train_op)
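
For context, _model_fn is typically bound to a concrete network constructor
and handed to an Estimator. The sketch below is illustrative only:
retinanet_architecture.retinanet, run_config, and the literal values are
assumptions, not taken from this file.

def retinanet_model_fn(features, labels, mode, params):
  # Illustrative wrapper: bind the assumed network constructor so _model_fn
  # can drive TRAIN/EVAL/PREDICT.
  return _model_fn(
      features,
      labels,
      mode,
      params,
      model=retinanet_architecture.retinanet,  # assumed constructor name
      use_tpu_estimator_spec=params['use_tpu'])

estimator = contrib_tpu.TPUEstimator(
    model_fn=retinanet_model_fn,
    use_tpu=params['use_tpu'],
    train_batch_size=64,  # illustrative value
    config=run_config,    # assumed contrib_tpu.RunConfig
    params=params)        # TPUEstimator injects params['batch_size'] itself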