in models/official/mask_rcnn/mask_rcnn_model.py
def _model_fn(features, labels, mode, params, variable_filter_fn=None):
"""Model defination for the Mask-RCNN model based on ResNet.
Args:
features: the input image tensor and auxiliary information, such as
`image_info` and `source_ids`. The image tensor has a shape of
[batch_size, height, width, 3]. The height and width are fixed and equal.
labels: the input labels in a dictionary. The labels include score targets
and box targets, which are dense label maps. The labels are generated by
the get_input_fn function in data/dataloader.py.
mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
params: the dictionary that defines the hyperparameters of the model. The
default settings are in the default_hparams function in this file.
variable_filter_fn: the filter function that takes trainable_variables and
returns the variable list after applying the filter rule.
Returns:
tpu_spec: the TPUEstimatorSpec to run training, evaluation, or prediction.
"""
if (mode == tf.estimator.ModeKeys.PREDICT or
mode == tf.estimator.ModeKeys.EVAL):
if ((params['include_groundtruth_in_features'] or
mode == tf.estimator.ModeKeys.EVAL) and ('labels' in features)):
# Include groundtruth for eval.
labels = features['labels']
if 'features' in features:
features = features['features']
# Otherwise, it is in export mode and the features are passed in directly.
if params['precision'] == 'bfloat16':
with tf.tpu.bfloat16_scope():
model_outputs = build_model_graph(features, labels,
mode == tf.estimator.ModeKeys.TRAIN,
params)
model_outputs.update({
'source_id': features['source_ids'],
'image_info': features['image_info'],
})
def cast_outputs_to_float(d):
for k, v in sorted(six.iteritems(d)):
if isinstance(v, dict):
cast_outputs_to_float(v)
else:
d[k] = tf.cast(v, tf.float32)
cast_outputs_to_float(model_outputs)
else:
model_outputs = build_model_graph(features, labels,
mode == tf.estimator.ModeKeys.TRAIN,
params)
model_outputs.update({
'source_id': features['source_ids'],
'image_info': features['image_info'],
})
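# Note: in the bfloat16 branch above, only the forward graph runs in reduced
# precision; cast_outputs_to_float recursively casts every output back to
# float32 so the loss computation and the TPU outfeed stay in full precision.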
# First check if it is in PREDICT or EVAL mode to fill out predictions.
# Predictions are used during the eval step to generate metrics.
predictions = {}
if (mode == tf.estimator.ModeKeys.PREDICT or
mode == tf.estimator.ModeKeys.EVAL):
if 'orig_images' in features:
model_outputs['orig_images'] = features['orig_images']
if labels and params['include_groundtruth_in_features']:
# Labels can only be embedded in predictions. A prediction cannot output a
# dictionary as a value.
predictions.update(labels)
model_outputs.pop('fpn_features', None)
predictions.update(model_outputs)
# If we are doing PREDICT, we can return here.
if mode == tf.estimator.ModeKeys.PREDICT:
if params['use_tpu']:
return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
predictions=predictions)
return tf.estimator.EstimatorSpec(mode=mode,
predictions=predictions)
# Set up training loss and learning rate.
global_step = tf.train.get_or_create_global_step()
if params['learning_rate_type'] == 'step':
learning_rate = learning_rates.step_learning_rate_with_linear_warmup(
global_step,
params['init_learning_rate'],
params['warmup_learning_rate'],
params['warmup_steps'],
params['learning_rate_levels'],
params['learning_rate_steps'])
elif params['learning_rate_type'] == 'cosine':
learning_rate = learning_rates.cosine_learning_rate_with_linear_warmup(
global_step,
params['init_learning_rate'],
params['warmup_learning_rate'],
params['warmup_steps'],
params['total_steps'])
else:
raise ValueError('Unsupported learning rate type: `{}`!'
.format(params['learning_rate_type']))
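# A rough sketch of the two schedules selected above, assuming the usual
# linear-warmup forms (the actual implementations live in learning_rates.py
# and may differ in detail):
#
#   def step_schedule(step):
#     if step < warmup_steps:
#       return warmup_learning_rate + (
#           init_learning_rate - warmup_learning_rate) * step / warmup_steps
#     lr = init_learning_rate
#     for level, boundary in zip(learning_rate_levels, learning_rate_steps):
#       if step >= boundary:
#         lr = level
#     return lr
#
#   def cosine_schedule(step):
#     if step < warmup_steps:
#       return warmup_learning_rate + (
#           init_learning_rate - warmup_learning_rate) * step / warmup_steps
#     progress = (step - warmup_steps) / float(total_steps - warmup_steps)
#     return 0.5 * init_learning_rate * (1 + math.cos(math.pi * progress))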
# score_loss and box_loss are for logging; only total_loss is optimized.
total_rpn_loss, rpn_score_loss, rpn_box_loss = losses.rpn_loss(
model_outputs['rpn_score_outputs'], model_outputs['rpn_box_outputs'],
labels, params)
(total_fast_rcnn_loss, fast_rcnn_class_loss,
fast_rcnn_box_loss) = losses.fast_rcnn_loss(
model_outputs['class_outputs'], model_outputs['box_outputs'],
model_outputs['class_targets'], model_outputs['box_targets'], params)
# Only training has the mask loss. Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/model_builder.py # pylint: disable=line-too-long
if mode == tf.estimator.ModeKeys.TRAIN and params['include_mask']:
mask_loss = losses.mask_rcnn_loss(
model_outputs['mask_outputs'], model_outputs['mask_targets'],
model_outputs['selected_class_targets'], params)
else:
mask_loss = 0.
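# losses.mask_rcnn_loss follows the Mask R-CNN formulation: a per-pixel
# sigmoid cross-entropy evaluated only on the mask channel of each sampled
# RoI's target class. A hedged pseudo-code sketch of that idea:
#
#   per_roi = sigmoid_xent(mask_outputs[i, selected_class_targets[i]],
#                          mask_targets[i])   # [mask_height, mask_width]
#   mask_loss = mean of per_roi over foreground RoIs and pixels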
if variable_filter_fn and ('resnet' in params['backbone']):
var_list = variable_filter_fn(tf.trainable_variables(),
params['backbone'] + '/')
else:
var_list = tf.trainable_variables()
l2_regularization_loss = params['l2_weight_decay'] * tf.add_n([
tf.nn.l2_loss(v)
for v in var_list
if 'batch_normalization' not in v.name and 'bias' not in v.name
])
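# Note: tf.nn.l2_loss(v) computes sum(v ** 2) / 2, so the term above equals
# (l2_weight_decay / 2) * ||w||^2 summed over all trainable weights except
# batch-norm parameters and biases, which are conventionally left undecayed.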
total_loss = (total_rpn_loss + total_fast_rcnn_loss + mask_loss +
l2_regularization_loss)
host_call = None
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = create_optimizer(learning_rate, params)
if params['use_tpu']:
optimizer = tf.tpu.CrossShardOptimizer(optimizer)
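# CrossShardOptimizer all-reduces the gradients across TPU cores (averaging
# them by default) before they are applied, so every core performs the same
# weight update.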
scaffold_fn = None
if params['warm_start_path']:
def warm_start_scaffold_fn():
logging.info(
'model_fn warm start from: %s', params['warm_start_path'])
assignment_map = _build_assigment_map(
optimizer,
prefix=None,
skip_variables_regex=params['skip_checkpoint_variables'])
tf.train.init_from_checkpoint(params['warm_start_path'], assignment_map)
return tf.train.Scaffold()
scaffold_fn = warm_start_scaffold_fn
elif params['checkpoint']:
def backbone_scaffold_fn():
"""Loads pretrained model through scaffold function."""
# Exclude all variables of the optimizer.
vars_to_load = _build_assigment_map(
optimizer,
prefix=params['backbone'] + '/',
skip_variables_regex=params['skip_checkpoint_variables'])
if not vars_to_load:
raise ValueError('The list of variables to load is empty.')
tf.train.init_from_checkpoint(params['checkpoint'], vars_to_load)
return tf.train.Scaffold()
scaffold_fn = backbone_scaffold_fn
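# _build_assigment_map (defined elsewhere in this file) is assumed to return
# an assignment_map dict in the form tf.train.init_from_checkpoint accepts,
# i.e. {name_in_checkpoint: name_in_graph_or_variable}, with optimizer slots
# and any skip_checkpoint_variables matches filtered out. Illustrative only:
#
#   {'resnet50/conv2d/kernel': 'resnet50/conv2d/kernel', ...}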
# Batch norm requires update_ops to be added as a train_op dependency.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
if params['global_gradient_clip_ratio'] > 0:
# Clip the gradients for training stability.
# Reference: https://arxiv.org/abs/1211.5063
with tf.name_scope('clipping'):
old_grads, variables = zip(*grads_and_vars)
num_weights = sum(
g.shape.num_elements() for g in old_grads if g is not None)
clip_norm = params['global_gradient_clip_ratio'] * math.sqrt(
num_weights)
logging.info(
'Global clip norm set to %g for %d variables with %d elements.',
clip_norm, sum(1 for g in old_grads if g is not None),
num_weights)
gradients, _ = tf.clip_by_global_norm(old_grads, clip_norm)
else:
gradients, variables = zip(*grads_and_vars)
grads_and_vars = []
# Special treatment for biases (beta is named 'bias' in the reference model).
# Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/optimizer.py#L113 # pylint: disable=line-too-long
for grad, var in zip(gradients, variables):
if grad is not None and ('beta' in var.name or 'bias' in var.name):
grad = 2.0 * grad
grads_and_vars.append((grad, var))
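# Doubling the gradient above emulates the 2x learning rate on biases used by
# the Detectron reference. For plain or momentum SGD this is exactly
# equivalent to scaling the bias learning rate, since the update is linear in
# the gradient; for adaptive optimizers such as Adam it would not be.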
with tf.control_dependencies(update_ops):
train_op = optimizer.apply_gradients(
grads_and_vars, global_step=global_step)
if params['use_host_call']:
def host_call_fn(global_step, total_loss, total_rpn_loss, rpn_score_loss,
rpn_box_loss, total_fast_rcnn_loss, fast_rcnn_class_loss,
fast_rcnn_box_loss, mask_loss, l2_regularization_loss,
learning_rate):
"""Training host call. Creates scalar summaries for training metrics.
This function is executed on the CPU and should not directly reference
any Tensors in the rest of the `model_fn`. To pass Tensors from the
model to the `metric_fn`, provide them as part of the `host_call`. See
https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
for more information.
Arguments should match the list of `Tensor` objects passed as the second
element in the tuple passed to `host_call`.
Args:
global_step: `Tensor` with shape `[batch, ]` for the global_step.
total_loss: `Tensor` with shape `[batch, ]` for the training loss.
total_rpn_loss: `Tensor` with shape `[batch, ]` for the training RPN
loss.
rpn_score_loss: `Tensor` with shape `[batch, ]` for the training RPN
score loss.
rpn_box_loss: `Tensor` with shape `[batch, ]` for the training RPN
box loss.
total_fast_rcnn_loss: `Tensor` with shape `[batch, ]` for the
training Fast-RCNN loss.
fast_rcnn_class_loss: `Tensor` with shape `[batch, ]` for the
training Fast-RCNN class loss.
fast_rcnn_box_loss: `Tensor` with shape `[batch, ]` for the
training Fast-RCNN box loss.
mask_loss: `Tensor` with shape `[batch, ]` for the training Mask-RCNN
mask loss.
l2_regularization_loss: `Tensor` with shape `[batch, ]` for the
regularization loss.
learning_rate: `Tensor` with shape `[batch, ]` for the learning_rate.
Returns:
List of summary ops to run on the CPU host.
"""
# Outfeed supports int32 but global_step is expected to be int64.
global_step = tf.reduce_mean(global_step)
# Host call fns are executed FLAGS.iterations_per_loop times after one
# TPU loop is finished. Setting max_queue to the same value as the number
# of iterations makes the summary writer flush the data to storage only
# once per loop.
with (tf2.summary.create_file_writer(
params['model_dir'],
max_queue=params['iterations_per_loop']).as_default()):
with tf2.summary.record_if(True):
tf2.summary.scalar(
'total_loss', tf.reduce_mean(total_loss), step=global_step)
tf2.summary.scalar(
'total_rpn_loss', tf.reduce_mean(total_rpn_loss),
step=global_step)
tf2.summary.scalar(
'rpn_score_loss', tf.reduce_mean(rpn_score_loss),
step=global_step)
tf2.summary.scalar(
'rpn_box_loss', tf.reduce_mean(rpn_box_loss), step=global_step)
tf2.summary.scalar(
'total_fast_rcnn_loss', tf.reduce_mean(total_fast_rcnn_loss),
step=global_step)
tf2.summary.scalar(
'fast_rcnn_class_loss', tf.reduce_mean(fast_rcnn_class_loss),
step=global_step)
tf2.summary.scalar(
'fast_rcnn_box_loss', tf.reduce_mean(fast_rcnn_box_loss),
step=global_step)
if params['include_mask']:
tf2.summary.scalar(
'mask_loss', tf.reduce_mean(mask_loss), step=global_step)
tf2.summary.scalar(
'l2_regularization_loss',
tf.reduce_mean(l2_regularization_loss),
step=global_step)
tf2.summary.scalar(
'learning_rate', tf.reduce_mean(learning_rate),
step=global_step)
return tf.summary.all_v2_summary_ops()
# To log the loss, current learning rate, and epoch for TensorBoard, the
# summary op needs to be run on the host CPU via host_call. host_call
# expects [batch_size, ...] Tensors, thus reshape to introduce a batch
# dimension. These Tensors are implicitly concatenated to
# [params['batch_size']].
global_step_t = tf.reshape(global_step, [1])
total_loss_t = tf.reshape(total_loss, [1])
total_rpn_loss_t = tf.reshape(total_rpn_loss, [1])
rpn_score_loss_t = tf.reshape(rpn_score_loss, [1])
rpn_box_loss_t = tf.reshape(rpn_box_loss, [1])
total_fast_rcnn_loss_t = tf.reshape(total_fast_rcnn_loss, [1])
fast_rcnn_class_loss_t = tf.reshape(fast_rcnn_class_loss, [1])
fast_rcnn_box_loss_t = tf.reshape(fast_rcnn_box_loss, [1])
mask_loss_t = tf.reshape(mask_loss, [1])
l2_regularization_loss_t = tf.reshape(l2_regularization_loss, [1])
learning_rate_t = tf.reshape(learning_rate, [1])
host_call = (host_call_fn,
[global_step_t, total_loss_t, total_rpn_loss_t,
rpn_score_loss_t, rpn_box_loss_t, total_fast_rcnn_loss_t,
fast_rcnn_class_loss_t, fast_rcnn_box_loss_t,
mask_loss_t, l2_regularization_loss_t, learning_rate_t])
else:
train_op = None
scaffold_fn = None
if params['use_tpu']:
return tf.estimator.tpu.TPUEstimatorSpec(
mode=mode,
loss=total_loss,
train_op=train_op,
host_call=host_call,
scaffold_fn=scaffold_fn)
return tf.estimator.EstimatorSpec(
mode=mode, loss=total_loss, train_op=train_op)
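# A hedged usage sketch (not part of this file): `_model_fn` is typically
# bound with functools.partial and handed to a (TPU)Estimator. The name
# `train_input_fn` and the literal batch size below are placeholders, not
# this module's API:
#
#   model_fn = functools.partial(_model_fn, variable_filter_fn=None)
#   estimator = tf.estimator.tpu.TPUEstimator(
#       model_fn=model_fn,
#       config=tf.estimator.tpu.RunConfig(...),  # cluster/iteration config
#       train_batch_size=64,
#       params=params,
#       use_tpu=params['use_tpu'])
#   estimator.train(input_fn=train_input_fn, max_steps=params['total_steps'])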