# _model_fn()
#
# in models/official/mask_rcnn/mask_rcnn_model.py [0:0]


def _model_fn(features, labels, mode, params, variable_filter_fn=None):
  """Model definition for the Mask-RCNN model based on ResNet.

  Builds the model graph via `build_model_graph`. In PREDICT mode the model
  outputs are returned as predictions. Otherwise the RPN, Fast-RCNN, mask
  (TRAIN only) and L2 regularization losses are summed into the total loss
  and, in TRAIN mode, the train op, optional checkpoint-loading scaffold
  function and optional summary host call are created.

  Args:
    features: the input image tensor and auxiliary information, such as
      `image_info` and `source_ids`. The image tensor has a shape of
      [batch_size, height, width, 3]. The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include score targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule. It is only
      applied when `params['backbone']` contains 'resnet'.

  Returns:
    A `tf.estimator.tpu.TPUEstimatorSpec` when `params['use_tpu']` is set,
    otherwise a `tf.estimator.EstimatorSpec`, to run training, evaluation, or
    prediction.

  Raises:
    ValueError: if `params['learning_rate_type']` is neither 'step' nor
      'cosine'.
  """
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  is_predict_or_eval = mode in (tf.estimator.ModeKeys.PREDICT,
                                tf.estimator.ModeKeys.EVAL)

  if is_predict_or_eval:
    if ((params['include_groundtruth_in_features'] or
         mode == tf.estimator.ModeKeys.EVAL) and ('labels' in features)):
      # Include the groundtruth for eval.
      labels = features['labels']

    if 'features' in features:
      features = features['features']
    # Otherwise, it is in export mode and the features are passed in directly.

  if params['precision'] == 'bfloat16':
    with tf.tpu.bfloat16_scope():
      model_outputs = build_model_graph(features, labels, is_training, params)
      model_outputs.update({
          'source_id': features['source_ids'],
          'image_info': features['image_info'],
      })

      def cast_outputs_to_float(d):
        """Recursively casts every non-dict value of `d` to float32 in place."""
        for k, v in sorted(six.iteritems(d)):
          if isinstance(v, dict):
            cast_outputs_to_float(v)
          else:
            d[k] = tf.cast(v, tf.float32)

      # NOTE(review): 'source_id' and 'image_info' were added above, so they
      # are cast to float32 as well on this path (unlike the non-bfloat16
      # path below). Confirm that downstream consumers expect this.
      cast_outputs_to_float(model_outputs)
  else:
    model_outputs = build_model_graph(features, labels, is_training, params)
    model_outputs.update({
        'source_id': features['source_ids'],
        'image_info': features['image_info'],
    })

  # First check if it is in PREDICT or EVAL mode to fill out predictions.
  # Predictions are used during the eval step to generate metrics.
  predictions = {}
  if is_predict_or_eval:
    if 'orig_images' in features:
      model_outputs['orig_images'] = features['orig_images']
    if labels and params['include_groundtruth_in_features']:
      # Labels can only be embedded in predictions. The prediction cannot
      # output a dictionary as a value.
      predictions.update(labels)
    model_outputs.pop('fpn_features', None)
    predictions.update(model_outputs)
    # If we are doing PREDICT, we can return here.
    if mode == tf.estimator.ModeKeys.PREDICT:
      if params['use_tpu']:
        return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                 predictions=predictions)
      return tf.estimator.EstimatorSpec(mode=mode,
                                        predictions=predictions)

  # Set up training loss and learning rate.
  global_step = tf.train.get_or_create_global_step()
  if params['learning_rate_type'] == 'step':
    learning_rate = learning_rates.step_learning_rate_with_linear_warmup(
        global_step,
        params['init_learning_rate'],
        params['warmup_learning_rate'],
        params['warmup_steps'],
        params['learning_rate_levels'],
        params['learning_rate_steps'])
  elif params['learning_rate_type'] == 'cosine':
    learning_rate = learning_rates.cosine_learning_rate_with_linear_warmup(
        global_step,
        params['init_learning_rate'],
        params['warmup_learning_rate'],
        params['warmup_steps'],
        params['total_steps'])
  else:
    raise ValueError('Unsupported learning rate type: `{}`!'
                     .format(params['learning_rate_type']))
  # score_loss and box_loss are for logging. only total_loss is optimized.
  total_rpn_loss, rpn_score_loss, rpn_box_loss = losses.rpn_loss(
      model_outputs['rpn_score_outputs'], model_outputs['rpn_box_outputs'],
      labels, params)

  (total_fast_rcnn_loss, fast_rcnn_class_loss,
   fast_rcnn_box_loss) = losses.fast_rcnn_loss(
       model_outputs['class_outputs'], model_outputs['box_outputs'],
       model_outputs['class_targets'], model_outputs['box_targets'], params)
  # Only training has the mask loss. Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/model_builder.py  # pylint: disable=line-too-long
  if is_training and params['include_mask']:
    mask_loss = losses.mask_rcnn_loss(
        model_outputs['mask_outputs'], model_outputs['mask_targets'],
        model_outputs['selected_class_targets'], params)
  else:
    mask_loss = 0.
  if variable_filter_fn and ('resnet' in params['backbone']):
    var_list = variable_filter_fn(tf.trainable_variables(),
                                  params['backbone'] + '/')
  else:
    var_list = tf.trainable_variables()
  # Batch-norm parameters and biases are excluded from weight decay.
  l2_regularization_loss = params['l2_weight_decay'] * tf.add_n([
      tf.nn.l2_loss(v)
      for v in var_list
      if 'batch_normalization' not in v.name and 'bias' not in v.name
  ])
  total_loss = (total_rpn_loss + total_fast_rcnn_loss + mask_loss +
                l2_regularization_loss)

  # Defaults for the non-training modes; overwritten below when training.
  train_op = None
  scaffold_fn = None
  host_call = None
  if is_training:
    optimizer = create_optimizer(learning_rate, params)
    if params['use_tpu']:
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    # `warm_start_path` takes precedence over `checkpoint` when both are set.
    if params['warm_start_path']:

      def warm_start_scaffold_fn():
        """Warm-starts variables from `params['warm_start_path']`."""
        logging.info(
            'model_fn warm start from: %s,', params['warm_start_path'])
        assignment_map = _build_assigment_map(
            optimizer,
            prefix=None,
            skip_variables_regex=params['skip_checkpoint_variables'])
        tf.train.init_from_checkpoint(params['warm_start_path'], assignment_map)
        return tf.train.Scaffold()

      scaffold_fn = warm_start_scaffold_fn

    elif params['checkpoint']:

      def backbone_scaffold_fn():
        """Loads pretrained model through scaffold function."""
        # Exclude all variable of optimizer.
        vars_to_load = _build_assigment_map(
            optimizer,
            prefix=params['backbone'] + '/',
            skip_variables_regex=params['skip_checkpoint_variables'])
        # Validate before touching the checkpoint so that an empty assignment
        # map is reported without first issuing a no-op load.
        if not vars_to_load:
          raise ValueError('Variables to load is empty.')
        tf.train.init_from_checkpoint(params['checkpoint'], vars_to_load)
        return tf.train.Scaffold()

      scaffold_fn = backbone_scaffold_fn

    # Batch norm requires update_ops to be added as a train_op dependency.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
    if params['global_gradient_clip_ratio'] > 0:
      # Clips the gradients for training stability.
      # Refer: https://arxiv.org/abs/1211.5063
      with tf.name_scope('clipping'):
        old_grads, variables = zip(*grads_and_vars)
        num_weights = sum(
            g.shape.num_elements() for g in old_grads if g is not None)
        # The clip norm scales with the square root of the number of elements
        # that actually receive a gradient.
        clip_norm = params['global_gradient_clip_ratio'] * math.sqrt(
            num_weights)
        logging.info(
            'Global clip norm set to %g for %d variables with %d elements.',
            clip_norm, sum(1 for g in old_grads if g is not None),
            num_weights)
        gradients, _ = tf.clip_by_global_norm(old_grads, clip_norm)
    else:
      gradients, variables = zip(*grads_and_vars)
    grads_and_vars = []
    # Special treatment for biases (beta is named as bias in reference model)
    # Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/optimizer.py#L113  # pylint: disable=line-too-long
    for grad, var in zip(gradients, variables):
      if grad is not None and ('beta' in var.name or 'bias' in var.name):
        grad = 2.0 * grad
      grads_and_vars.append((grad, var))

    with tf.control_dependencies(update_ops):
      train_op = optimizer.apply_gradients(
          grads_and_vars, global_step=global_step)

    if params['use_host_call']:
      def host_call_fn(global_step, total_loss, total_rpn_loss, rpn_score_loss,
                       rpn_box_loss, total_fast_rcnn_loss, fast_rcnn_class_loss,
                       fast_rcnn_box_loss, mask_loss, l2_regularization_loss,
                       learning_rate):
        """Training host call. Creates scalar summaries for training metrics.

        This function is executed on the CPU and should not directly reference
        any Tensors in the rest of the `model_fn`. To pass Tensors from the
        model to the `metric_fn`, provide as part of the `host_call`. See
        https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
        for more information.

        Arguments should match the list of `Tensor` objects passed as the second
        element in the tuple passed to `host_call`.

        Args:
          global_step: `Tensor` with shape `[batch, ]` for the global_step.
          total_loss: `Tensor` with shape `[batch, ]` for the training loss.
          total_rpn_loss: `Tensor` with shape `[batch, ]` for the training RPN
            loss.
          rpn_score_loss: `Tensor` with shape `[batch, ]` for the training RPN
            score loss.
          rpn_box_loss: `Tensor` with shape `[batch, ]` for the training RPN
            box loss.
          total_fast_rcnn_loss: `Tensor` with shape `[batch, ]` for the
            training Mask-RCNN loss.
          fast_rcnn_class_loss: `Tensor` with shape `[batch, ]` for the
            training Mask-RCNN class loss.
          fast_rcnn_box_loss: `Tensor` with shape `[batch, ]` for the
            training Mask-RCNN box loss.
          mask_loss: `Tensor` with shape `[batch, ]` for the training Mask-RCNN
            mask loss. Only summarized when `params['include_mask']` is set.
          l2_regularization_loss: `Tensor` with shape `[batch, ]` for the
            regularization loss.
          learning_rate: `Tensor` with shape `[batch, ]` for the learning_rate.

        Returns:
          List of summary ops to run on the CPU host.
        """
        # Outfeed supports int32 but global_step is expected to be int64.
        global_step = tf.reduce_mean(global_step)

        # (summary tag, batched tensor) pairs, in the order they are written.
        scalars = [
            ('total_loss', total_loss),
            ('total_rpn_loss', total_rpn_loss),
            ('rpn_score_loss', rpn_score_loss),
            ('rpn_box_loss', rpn_box_loss),
            ('total_fast_rcnn_loss', total_fast_rcnn_loss),
            ('fast_rcnn_class_loss', fast_rcnn_class_loss),
            ('fast_rcnn_box_loss', fast_rcnn_box_loss),
        ]
        if params['include_mask']:
          scalars.append(('mask_loss', mask_loss))
        scalars.extend([
            ('l2_regularization_loss', l2_regularization_loss),
            ('learning_rate', learning_rate),
        ])

        # Host call fns are executed FLAGS.iterations_per_loop times after one
        # TPU loop is finished, setting max_queue value to the same as number of
        # iterations will make the summary writer only flush the data to storage
        # once per loop.
        with (tf2.summary.create_file_writer(
            params['model_dir'],
            max_queue=params['iterations_per_loop']).as_default()):
          with tf2.summary.record_if(True):
            for tag, value in scalars:
              tf2.summary.scalar(tag, tf.reduce_mean(value), step=global_step)

            return tf.summary.all_v2_summary_ops()

      # To log the loss, current learning rate, and epoch for Tensorboard, the
      # summary op needs to be run on the host CPU via host_call. host_call
      # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
      # dimension. These Tensors are implicitly concatenated to
      # [params['batch_size']].
      # The order must match the positional parameters of `host_call_fn`.
      host_call_tensors = [
          tf.reshape(tensor, [1]) for tensor in (
              global_step, total_loss, total_rpn_loss, rpn_score_loss,
              rpn_box_loss, total_fast_rcnn_loss, fast_rcnn_class_loss,
              fast_rcnn_box_loss, mask_loss, l2_regularization_loss,
              learning_rate)
      ]
      host_call = (host_call_fn, host_call_tensors)

  if params['use_tpu']:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        host_call=host_call,
        scaffold_fn=scaffold_fn)
  # NOTE(review): on the non-TPU path `scaffold_fn` and `host_call` are not
  # forwarded, so `warm_start_path` / `checkpoint` loading and the training
  # summaries built above have no effect here. Confirm whether that is
  # intended before relying on pretrained weights without a TPU.
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=total_loss, train_op=train_op)