# resnet_generator() — excerpt from models/official/resnet/resnet_model.py

def resnet_generator(block_fn,
                     layers,
                     num_classes,
                     data_format='channels_first',
                     use_resnetd_stem=False,
                     resnetd_shortcut=False,
                     replace_stem_max_pool=False,
                     skip_stem_max_pool=False,
                     drop_connect_rate=None,
                     se_ratio=None,
                     dropout_rate=None,
                     dropblock_keep_probs=None,
                     dropblock_size=None,
                     pre_activation=False,
                     norm_act_layer=LAYER_BN_RELU,
                     bn_momentum=MOVING_AVERAGE_DECAY):
  """Generator for ResNet models.

  Args:
    block_fn: `function` for the block to use within the model. Either
        `residual_block` or `bottleneck_block`.
    layers: list of 4 `int`s denoting the number of blocks to include in each
      of the 4 block groups. Each group consists of blocks that take inputs of
      the same resolution.
    num_classes: `int` number of possible classes for image classification.
    data_format: `str` either "channels_first" for `[batch, channels, height,
        width]` or "channels_last for `[batch, height, width, channels]`.
    use_resnetd_stem: `bool` whether to use ResNet-D stem.
    resnetd_shortcut: `bool` whether to use ResNet-D shortcut in blocks.
    replace_stem_max_pool: `bool` if True, replace the max pool in stem with
        a stride-2 conv.
    skip_stem_max_pool: `bool` if True, skip the max pool in stem and set the
        stride of the following block to 2.
    drop_connect_rate: `float` initial rate for drop-connect.
    se_ratio: `float` Squeeze-and-Excitation ratio for SE layers.
    dropout_rate: `float` drop rate for the dropout layer.
    dropblock_keep_probs: `list` of 4 elements denoting keep_prob of DropBlock
      for each block group. None indicates no DropBlock for the corresponding
      block group.
    dropblock_size: `int`: size parameter of DropBlock.
    pre_activation: whether to use pre-activation ResNet (ResNet-v2).
    norm_act_layer: name of the normalization-activation layer.
    bn_momentum: `float` momentum for batch norm layer.

  Returns:
    Model `function` that takes in `inputs` and `is_training` and returns the
    output `Tensor` of the ResNet model.

  Raises:
    ValueError: if `dropblock_keep_probs` is neither None nor a list of
      length 4.
  """
  if dropblock_keep_probs is None:
    dropblock_keep_probs = [None] * 4
  if not isinstance(dropblock_keep_probs,
                    list) or len(dropblock_keep_probs) != 4:
    raise ValueError(
        'dropblock_keep_probs is not valid: %s' % dropblock_keep_probs)

  def model(inputs, is_training):
    """Creation of the model graph."""
    if use_resnetd_stem:
      # ResNet-D stem: three 3x3 convs (32, 32, 64 filters) replacing the
      # single 7x7 conv of the vanilla stem.
      inputs = conv2d_fixed_padding(
          inputs=inputs, filters=32, kernel_size=3, strides=2,
          data_format=data_format)
      inputs = norm_activation(
          inputs, is_training, data_format=data_format,
          layer=norm_act_layer, bn_momentum=bn_momentum)
      inputs = conv2d_fixed_padding(
          inputs=inputs, filters=32, kernel_size=3, strides=1,
          data_format=data_format)
      inputs = norm_activation(
          inputs, is_training, data_format=data_format,
          layer=norm_act_layer, bn_momentum=bn_momentum)
      inputs = conv2d_fixed_padding(
          inputs=inputs, filters=64, kernel_size=3, strides=1,
          data_format=data_format)
    else:
      inputs = conv2d_fixed_padding(
          inputs=inputs, filters=64, kernel_size=7, strides=2,
          data_format=data_format)

    inputs = tf.identity(inputs, 'initial_conv')
    # In pre-activation (v2) ResNets, normalization happens inside the
    # blocks, so the stem conv is not normalized here.
    if not pre_activation:
      inputs = norm_activation(inputs, is_training, data_format=data_format,
                               layer=norm_act_layer, bn_momentum=bn_momentum)

    if not skip_stem_max_pool:
      if replace_stem_max_pool:
        inputs = conv2d_fixed_padding(
            inputs=inputs, filters=64,
            kernel_size=3, strides=2, data_format=data_format)
        # Pass norm_act_layer for consistency with the other stem
        # norm-activation calls (previously this call fell back to the
        # default layer regardless of norm_act_layer).
        inputs = norm_activation(
            inputs, is_training, data_format=data_format,
            layer=norm_act_layer, bn_momentum=bn_momentum)
      else:
        inputs = tf.layers.max_pooling2d(
            inputs=inputs, pool_size=3, strides=2, padding='SAME',
            data_format=data_format)
        inputs = tf.identity(inputs, 'initial_max_pool')

    # Bind all the options that are identical across the four block groups.
    custom_block_group = functools.partial(
        block_group,
        data_format=data_format,
        dropblock_size=dropblock_size,
        pre_activation=pre_activation,
        norm_act_layer=norm_act_layer,
        se_ratio=se_ratio,
        resnetd_shortcut=resnetd_shortcut,
        bn_momentum=bn_momentum)

    # Used to scale drop-connect linearly with depth (block indices 2..5).
    num_layers = len(layers) + 1
    # If the stem max pool was skipped, recover the lost stride-2 here.
    stride_c2 = 2 if skip_stem_max_pool else 1

    inputs = custom_block_group(
        inputs=inputs, filters=64, block_fn=block_fn, blocks=layers[0],
        strides=stride_c2, is_training=is_training, name='block_group1',
        dropblock_keep_prob=dropblock_keep_probs[0],
        drop_connect_rate=resnet_layers.get_drop_connect_rate(
            drop_connect_rate, 2, num_layers))
    inputs = custom_block_group(
        inputs=inputs, filters=128, block_fn=block_fn, blocks=layers[1],
        strides=2, is_training=is_training, name='block_group2',
        dropblock_keep_prob=dropblock_keep_probs[1],
        drop_connect_rate=resnet_layers.get_drop_connect_rate(
            drop_connect_rate, 3, num_layers))
    inputs = custom_block_group(
        inputs=inputs, filters=256, block_fn=block_fn, blocks=layers[2],
        strides=2, is_training=is_training, name='block_group3',
        dropblock_keep_prob=dropblock_keep_probs[2],
        drop_connect_rate=resnet_layers.get_drop_connect_rate(
            drop_connect_rate, 4, num_layers))
    inputs = custom_block_group(
        inputs=inputs, filters=512, block_fn=block_fn, blocks=layers[3],
        strides=2, is_training=is_training, name='block_group4',
        dropblock_keep_prob=dropblock_keep_probs[3],
        drop_connect_rate=resnet_layers.get_drop_connect_rate(
            drop_connect_rate, 5, num_layers))

    # ResNet-v2 applies the final normalization after the last block.
    if pre_activation:
      inputs = norm_activation(inputs, is_training, data_format=data_format,
                               layer=norm_act_layer, bn_momentum=bn_momentum)

    # The activation is 7x7 so this is a global average pool.
    # TODO(huangyp): reduce_mean will be faster.
    if data_format == 'channels_last':
      pool_size = (inputs.shape[1], inputs.shape[2])
    else:
      pool_size = (inputs.shape[2], inputs.shape[3])
    inputs = tf.layers.average_pooling2d(
        inputs=inputs, pool_size=pool_size, strides=1, padding='VALID',
        data_format=data_format)
    inputs = tf.identity(inputs, 'final_avg_pool')
    # Bottleneck blocks expand channels 4x (512 * 4 = 2048); basic residual
    # blocks end at 512 channels.
    inputs = tf.reshape(
        inputs, [-1, 2048 if block_fn is bottleneck_block else 512])

    if dropout_rate is not None:
      tf.logging.info('using dropout')
      inputs = tf.layers.dropout(
          inputs, rate=dropout_rate, training=is_training)

    inputs = tf.layers.dense(
        inputs=inputs,
        units=num_classes,
        kernel_initializer=tf.random_normal_initializer(stddev=.01))
    inputs = tf.identity(inputs, 'final_dense')
    return inputs

  model.default_image_size = 224
  return model