in models/official/retinanet/retinanet_model.py [0:0]
def _model_fn(features, labels, mode, params, model, use_tpu_estimator_spec,
              variable_filter_fn=None):
  """Model definition for the RetinaNet model based on ResNet.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal. In PREDICT mode this may
      instead be a dict holding the image tensor under 'inputs', plus optional
      'image_info' and 'labels' entries.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in dataloader.py
    mode: the mode of TPUEstimator/Estimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: a callable returning a pair (cls_outputs, box_outputs) of dicts
      keyed by feature level, holding class logits and box regression outputs.
    use_tpu_estimator_spec: Whether to use TPUEstimatorSpec or EstimatorSpec.
    variable_filter_fn: optional filter function called with the trainable
      variables and params['resnet_depth']; it returns the list of variables
      to optimize. When None, all trainable variables are optimized.

  Returns:
    A contrib_tpu.TPUEstimatorSpec when `use_tpu_estimator_spec` is true,
    otherwise a tf.estimator.EstimatorSpec, to run training, evaluation, or
    prediction.
  """
  # In predict mode features is a dict with input as value of the 'inputs'.
  image_info = None
  if (mode == tf.estimator.ModeKeys.PREDICT
      and isinstance(features, dict) and 'inputs' in features):
    # 'image_info' is optional: it is only attached to the predictions when it
    # is present (see below), so a missing key must not raise a KeyError here.
    image_info = features.get('image_info')
    labels = features.get('labels')
    features = features['inputs']

  def _model_outputs():
    # Anchors per location = number of aspect ratios x number of scales.
    # Taking len() first keeps this correct even if aspect_ratios is an array
    # (where `*` would multiply element-wise instead of repeating).
    num_anchors = len(params['aspect_ratios']) * params['num_scales']
    return model(
        features,
        min_level=params['min_level'],
        max_level=params['max_level'],
        num_classes=params['num_classes'],
        num_anchors=num_anchors,
        resnet_depth=params['resnet_depth'],
        is_training_bn=params['is_training_bn'])

  if params['use_bfloat16']:
    with contrib_tpu.bfloat16_scope():
      cls_outputs, box_outputs = _model_outputs()
      # Cast every level's outputs to float32 so all code below receives
      # float32 tensors regardless of the bfloat16 setting.
      for level in list(cls_outputs.keys()):
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
  else:
    cls_outputs, box_outputs = _model_outputs()

  # First check if it is in PREDICT mode.
  if mode == tf.estimator.ModeKeys.PREDICT:
    # Postprocess on host; memory layout for NMS on TPU is very inefficient.
    def _predict_postprocess_wrapper(args):
      return _predict_postprocess(*args)

    predictions = contrib_tpu.outside_compilation(
        _predict_postprocess_wrapper,
        (cls_outputs, box_outputs, labels, params))
    # Include resizing information on prediction output to help bbox drawing.
    if image_info is not None:
      predictions.update({
          'image_info': tf.identity(image_info, 'ImageInfo'),
      })
    if use_tpu_estimator_spec:
      return contrib_tpu.TPUEstimatorSpec(
          mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions)
    # Honor use_tpu_estimator_spec here too, as the TRAIN/EVAL return at the
    # end of this function does: a plain Estimator needs an EstimatorSpec.
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT, predictions=predictions)

  # Load pretrained model from checkpoint.
  if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:
    def scaffold_fn():
      """Loads pretrained model through scaffold function."""
      tf.train.init_from_checkpoint(params['resnet_checkpoint'], {
          '/': 'resnet%s/' % params['resnet_depth'],
      })
      return tf.train.Scaffold()
  else:
    scaffold_fn = None

  # Set up training loss and learning rate.
  update_learning_rate_schedule_parameters(params)
  global_step = tf.train.get_global_step()
  learning_rate = learning_rate_schedule(
      params['adjusted_learning_rate'], params['lr_warmup_init'],
      params['lr_warmup_step'], params['first_lr_drop_step'],
      params['second_lr_drop_step'], global_step)
  # cls_loss and box_loss are for logging. only total_loss is optimized.
  total_loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs,
                                                  labels, params)
  # L2 weight decay over all trainable variables except batch-norm ones.
  total_loss += _WEIGHT_DECAY * tf.add_n([
      tf.nn.l2_loss(v)
      for v in tf.trainable_variables()
      if 'batch_normalization' not in v.name
  ])

  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, momentum=params['momentum'])
    if params['use_tpu']:
      optimizer = contrib_tpu.CrossShardOptimizer(optimizer)
    elif params['auto_mixed_precision']:
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
          optimizer)

    # Batch norm requires `update_ops` to be executed alongside `train_op`.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    var_list = variable_filter_fn(
        tf.trainable_variables(),
        params['resnet_depth']) if variable_filter_fn else None
    minimize_op = optimizer.minimize(total_loss, global_step, var_list=var_list)
    train_op = tf.group(minimize_op, update_ops)
  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:
    def metric_fn(**kwargs):
      """Returns a dictionary that has the evaluation metrics."""
      batch_size = params['batch_size']
      eval_anchors = anchors.Anchors(
          params['min_level'], params['max_level'], params['num_scales'],
          params['aspect_ratios'], params['anchor_scale'], params['image_size'])
      anchor_labeler = anchors.AnchorLabeler(eval_anchors,
                                             params['num_classes'])
      cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
      box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
      coco_metrics = coco_metric_fn(batch_size, anchor_labeler,
                                    params['val_json_file'], **kwargs)

      # Add metrics to output.
      output_metrics = {
          'cls_loss': cls_loss,
          'box_loss': box_loss,
      }
      output_metrics.update(coco_metrics)
      return output_metrics

    # Tile the scalar losses into [batch_size, 1] tensors so they can be
    # passed to metric_fn alongside the per-example inputs below.
    cls_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(cls_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    box_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(box_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    metric_fn_inputs = {
        'cls_loss_repeat': cls_loss_repeat,
        'box_loss_repeat': box_loss_repeat,
        'source_ids': labels['source_ids'],
        'groundtruth_data': labels['groundtruth_data'],
        'image_scales': labels['image_scales'],
    }
    add_metric_fn_inputs(params, cls_outputs, box_outputs, metric_fn_inputs)
    eval_metrics = (metric_fn, metric_fn_inputs)

  if use_tpu_estimator_spec:
    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        scaffold_fn=scaffold_fn)
  # NOTE(review): eval_metrics and scaffold_fn are not forwarded on this path,
  # so EVAL through a plain Estimator reports only the loss.
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=total_loss,
      # TODO(rostam): Fix bug to get scaffold working.
      # scaffold=scaffold_fn(),
      train_op=train_op)