def multilevel_crop_and_resize()

in models/official/detection/ops/spatial_transform_ops.py [0:0]


def multilevel_crop_and_resize(features,
                               boxes,
                               output_size=7,
                               use_einsum_gather=False):
  """Crop and resize on multilevel feature pyramid.

  Generate the (output_size, output_size) set of pixels for each input box
  by first locating the box into the correct feature level, and then cropping
  and resizing it using the corresponding feature map of that level.

  Here is the step-by-step algorithm with use_einsum_gather=True:
  1. Compute sampling points and their four neighbors for each output point.
     Each box is mapped to [output_size, output_size] points.
     Each output point is averaged among #sampling_ratio^2 points.
     Each sampling point is computed using bilinear
     interpolation of its four neighboring points on the feature map.
  2. Gather output points separately for each level. Gather and computation of
     output points are done for the boxes mapped to this level only.
     2.1. Compute indices of four neighboring point of each sampling
          point for x and y separately of shape
          [batch_size, num_boxes, output_size, 2].
     2.2. Compute the interpolation kernel for axis x and y separately of
          shape [batch_size, num_boxes, output_size, 2, 1].
     2.3. The features are collected into a
          [batch_size, num_boxes, output_size, output_size, num_filters]
          Tensor.
          Instead of a one-step algorithm, a two-step approach is used.
          That is, first, an intermediate output is stored with a shape of
          [batch_size, num_boxes, output_size, width, num_filters];
          second, the final output is produced with a shape of
          [batch_size, num_boxes, output_size, output_size, num_filters].

          Bilinear interpolation is done during the two step gather:
          f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
                                [f10, f11]]
          [[f00, f01],
           [f10, f11]] = tf.einsum(tf.einsum(features, y_one_hot), x_one_hot)
          where [hy, ly] and [hx, lx] are the bilinear interpolation kernel.

          Note:
            a. Use one_hot with einsum to replace gather;
            b. Bilinear interpolation and averaging of
               multiple sampling points are fused into the one_hot vector.
  3. Each level's result is zeroed for boxes not assigned to that level and
     the per-level results are summed, so every box receives features from
     exactly one level.

  With use_einsum_gather=False, all levels are flattened into a single
  [batch_size * sum_l(height_l * width_l), num_filters] tensor, the
  2 * output_size neighboring rows/columns of every box are fetched with one
  tf.gather, and feature_bilinear_interpolation combines them.

  Args:
    features: A dictionary with key as pyramid level and value as features. The
      features are in shape of [batch_size, height_l, width_l, num_filters].
      Every level between min and max key must be present and have a static
      spatial shape.
    boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row represents
      a box with [y1, x1, y2, x2] in un-normalized coordinates.
    output_size: A scalar to indicate the output crop size.
    use_einsum_gather: use einsum to replace gather or not. Replacing gather
      with einsum can improve performance when feature size is not large, and
      einsum is friendly with model partition as well. Gather's performance is
      better when feature size is very large and there are multiple box levels.

  Returns:
    A 5-D tensor representing feature crop of shape
    [batch_size, num_boxes, output_size, output_size, num_filters].
  """

  with tf.name_scope('multilevel_crop_and_resize'):
    levels = list(features.keys())
    min_level = min(levels)
    max_level = max(levels)
    batch_size, _, _, num_filters = (
        features[min_level].get_shape().as_list())
    if batch_size is None:
      batch_size = tf.shape(features[min_level])[0]
    _, num_boxes, _ = boxes.get_shape().as_list()

    # Assigns boxes to the right level following the FPN rule
    # level = floor(4 + log2(sqrt(box_h * box_w) / 224)).
    box_width = boxes[:, :, 3] - boxes[:, :, 1]
    box_height = boxes[:, :, 2] - boxes[:, :, 0]
    areas_sqrt = tf.sqrt(box_height * box_width)
    levels = tf.cast(
        tf.floordiv(tf.log(tf.div(areas_sqrt, 224.0)), tf.log(2.0)) + 4.0,
        dtype=tf.int32)
    # Maps levels between [min_level, max_level].
    levels = tf.minimum(max_level, tf.maximum(levels, min_level))

    # Projects box location and sizes to corresponding feature levels.
    scale_to_level = tf.cast(
        tf.pow(tf.constant(2.0), tf.cast(levels, tf.float32)),
        dtype=boxes.dtype)
    boxes /= tf.expand_dims(scale_to_level, axis=2)
    box_width /= scale_to_level
    box_height /= scale_to_level
    # Re-encodes boxes as [y1, x1, height, width] in level coordinates before
    # they are handed to compute_grid_positions.
    boxes = tf.concat([
        boxes[:, :, 0:2],
        tf.expand_dims(box_height, -1),
        tf.expand_dims(box_width, -1)
    ],
                      axis=-1)

    if use_einsum_gather:

      def two_step_gather_per_level(features_level, mask):
        """Performs two-step gather using einsum for every level of features.

        Args:
          features_level: features of one level,
            [batch_size, height_l, width_l, num_filters].
          mask: boolean Tensor broadcast to the output shape; True where the
            box is assigned to this level.

        Returns:
          [batch_size, num_boxes, output_size, output_size, num_filters] crops,
          zero for boxes not assigned to this level.
        """
        (_, feature_height, feature_width,
         _) = features_level.get_shape().as_list()
        # NOTE(review): the gather path passes the last valid index
        # (size - 1) as boundary, while this path passes the level size
        # itself; confirm against compute_grid_positions / get_grid_one_hot
        # that the difference is intended.
        boundaries = tf.tile(
            tf.expand_dims(
                tf.expand_dims([feature_height, feature_width], 0), 0),
            [batch_size, num_boxes, 1])
        boundaries = tf.cast(boundaries, boxes.dtype)
        kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = compute_grid_positions(
            boxes, boundaries, output_size, sample_offset=0.5)

        # shape is:
        # [batch_size, num_boxes, output_size, 2, spatial_size]
        box_grid_y_one_hot, box_grid_x_one_hot = get_grid_one_hot(
            box_gridy0y1, box_gridx0x1, feature_height, feature_width)

        # Fuses the interpolation kernel into the one-hot vectors by summing
        # over the two neighbors.
        # shape is [batch_size, num_boxes, output_size, spatial_size]
        box_grid_y_weight = tf.reduce_sum(
            tf.multiply(box_grid_y_one_hot, kernel_y), axis=-2)
        box_grid_x_weight = tf.reduce_sum(
            tf.multiply(box_grid_x_one_hot, kernel_x), axis=-2)

        # shape is [batch_size, num_boxes, output_size, width, feature]
        y_outputs = tf.einsum(
            'bhwf,bnyh->bnywf', features_level,
            tf.cast(box_grid_y_weight, dtype=features_level.dtype))

        # shape is [batch_size, num_boxes, output_size, output_size, feature]
        x_outputs = tf.einsum(
            'bnywf,bnxw->bnyxf', y_outputs,
            tf.cast(box_grid_x_weight, dtype=features_level.dtype))

        # Keep results only for boxes assigned to this level.
        return tf.where(mask, x_outputs, tf.zeros_like(x_outputs))

      features_per_box = tf.zeros(
          [batch_size, num_boxes, output_size, output_size, num_filters],
          dtype=features[min_level].dtype)
      for level in range(min_level, max_level + 1):
        level_equal = tf.equal(levels, level)
        mask = tf.tile(
            tf.reshape(level_equal, [batch_size, num_boxes, 1, 1, 1]),
            [1, 1, output_size, output_size, num_filters])
        features_per_box += two_step_gather_per_level(features[level], mask)

      return features_per_box

    # Flattens every level to [batch_size, height_l * width_l, num_filters],
    # concatenates the levels along the spatial axis and collapses the batch
    # axis, giving features_r2 of shape
    # [batch_size * sum_l(height_l * width_l), num_filters].
    features_all = []
    feature_heights = []
    feature_widths = []
    for level in range(min_level, max_level + 1):
      shape = features[level].get_shape().as_list()
      feature_heights.append(shape[1])
      feature_widths.append(shape[2])
      features_all.append(
          tf.reshape(features[level], [batch_size, -1, num_filters]))
    features_r2 = tf.reshape(tf.concat(features_all, 1), [-1, num_filters])

    # height_l * width_l for each level.
    level_dim_sizes = [
        height * width
        for height, width in zip(feature_heights, feature_widths)
    ]
    # level_dim_offsets[i] is the start of level i inside one batch element.
    level_dim_offsets = [0]
    for size in level_dim_sizes[:-1]:
      level_dim_offsets.append(level_dim_offsets[-1] + size)
    batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]
    level_dim_offsets = tf.constant(level_dim_offsets, tf.int32)
    # Row stride (i.e. width) of each level in the flattened layout.
    level_row_strides = tf.constant(feature_widths, tf.int32)

    # Maps levels to [0, max_level-min_level].
    levels -= min_level
    # Last valid (y, x) index of the level each box is assigned to, taken from
    # the real per-level shapes rather than assuming each level is exactly the
    # finest one divided by 2**level. Presumably compute_grid_positions clips
    # the sample indices to this bound.
    level_max_y = tf.constant([h - 1 for h in feature_heights], tf.float32)
    level_max_x = tf.constant([w - 1 for w in feature_widths], tf.float32)
    boundary = tf.cast(
        tf.stack([
            tf.gather(level_max_y, levels),
            tf.gather(level_max_x, levels),
        ],
                 axis=-1), boxes.dtype)

    # Compute grid positions.
    kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = compute_grid_positions(
        boxes, boundary, output_size, sample_offset=0.5)

    x_indices = tf.cast(
        tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
        dtype=tf.int32)
    y_indices = tf.cast(
        tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
        dtype=tf.int32)

    # Flat index into features_r2:
    #   batch * batch_dim_size + level_offset + y * width_l + x,
    # each term broadcast to
    # [batch_size, num_boxes, output_size * 2, output_size * 2].
    batch_size_offset = tf.tile(
        tf.reshape(
            tf.range(batch_size) * batch_dim_size, [batch_size, 1, 1, 1]),
        [1, num_boxes, output_size * 2, output_size * 2])
    # Get level offset for each box. Each box belongs to one level.
    levels_offset = tf.tile(
        tf.reshape(
            tf.gather(level_dim_offsets, levels),
            [batch_size, num_boxes, 1, 1]),
        [1, 1, output_size * 2, output_size * 2])
    row_strides = tf.expand_dims(tf.gather(level_row_strides, levels), -1)
    y_indices_offset = tf.tile(
        tf.reshape(y_indices * row_strides,
                   [batch_size, num_boxes, output_size * 2, 1]),
        [1, 1, 1, output_size * 2])
    x_indices_offset = tf.tile(
        tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
        [1, 1, output_size * 2, 1])
    indices = tf.reshape(
        batch_size_offset + levels_offset + y_indices_offset + x_indices_offset,
        [-1])

    # TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
    # performance.
    features_per_box = tf.reshape(
        tf.gather(features_r2, indices),
        [batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])

    # Bilinear interpolation.
    features_per_box = feature_bilinear_interpolation(features_per_box,
                                                      kernel_y, kernel_x)
    return features_per_box