in tensorflow/sagemakercv/training/losses/frcnn_losses.py [0:0]
def __call__(self,
             class_outputs,
             box_outputs,
             class_targets,
             box_targets,
             rpn_box_rois,
             image_info):
    """Compute the Fast R-CNN (second-stage) classification and box losses.

    Args:
        class_outputs: per-ROI class logits. Fed to the class loss and, when
            ``self.use_carl_loss`` is set, to the CARL box loss.
        box_outputs: per-ROI, per-class box regression outputs. The shape must
            be fully static: it is read with ``get_shape().as_list()`` as
            ``[batch_size, num_rois, _]`` and reshaped to
            ``[batch_size, num_rois, box_num_classes, 4]``.
        class_targets: per-ROI ground-truth class ids (cast to int32 here).
            Presumably ``[batch_size, num_rois]``, since it is added
            elementwise to ``[batch_size, num_rois]`` index offsets.
        box_targets: regression targets, passed through to the box loss.
        rpn_box_rois: proposal boxes. They are used as anchors to decode the
            selected ``box_outputs`` when ``self.box_loss_type`` is
            ``'giou'`` or ``'ciou'``.
        image_info: not used by this method. It is kept for interface
            compatibility.

    Returns:
        A tuple ``(total_loss, class_loss, box_loss)`` where
        ``total_loss = class_loss + box_loss``. ``box_loss`` is already
        multiplied by ``self.fast_rcnn_box_loss_weight``.

    Raises:
        NotImplementedError: if both ``self.use_carl_loss`` and
            ``self.class_agnostic_box`` are set.
    """
    if self.use_carl_loss and self.class_agnostic_box:
        # Fail before building any ops: the CARL box loss has no
        # class-agnostic variant.
        raise NotImplementedError(
            'CARL loss is not implemented for class-agnostic box regression.')

    with tf.name_scope('fast_rcnn_loss'):
        class_targets = tf.cast(class_targets, dtype=tf.int32)
        # Class-agnostic regression predicts only two box slots, so class 0
        # maps to slot 0 and every other class maps to slot 1.
        mask_targets = (tf.clip_by_value(class_targets, 0, 1)
                        if self.class_agnostic_box else class_targets)

        # Select, for every ROI, the box predicted for its target class (the
        # class with which the ROI has the maximum overlap).
        batch_size, num_rois, _ = box_outputs.get_shape().as_list()
        box_num_classes = 2 if self.class_agnostic_box else self.num_classes
        box_outputs = tf.reshape(
            box_outputs, [batch_size, num_rois, box_num_classes, 4])

        # Flat row index of (image b, roi r, slot c) in the
        # [batch_size * num_rois * box_num_classes, 4] view of box_outputs:
        #   b * num_rois * box_num_classes + r * box_num_classes + c
        batch_offsets = tf.tile(
            tf.expand_dims(tf.range(batch_size) * num_rois * box_num_classes, 1),
            [1, num_rois])
        roi_offsets = tf.tile(
            tf.expand_dims(tf.range(num_rois) * box_num_classes, 0),
            [batch_size, 1])
        box_indices = tf.reshape(mask_targets + batch_offsets + roi_offsets, [-1])

        # Gather via a one-hot matmul: row i of the result equals row
        # box_indices[i] of the flattened box_outputs.
        # NOTE(review): the one-hot operand is
        # [B*R, B*R*box_num_classes], so memory grows quadratically in
        # batch_size * num_rois. This looks like the TPU-friendly
        # formulation. tf.gather(flat_box_outputs, box_indices) yields the
        # same values for finite inputs at far lower cost, so consider it if
        # this path never targets TPUs.
        box_outputs = tf.matmul(
            tf.one_hot(
                box_indices,
                batch_size * num_rois * box_num_classes,
                dtype=box_outputs.dtype),
            tf.reshape(box_outputs, [-1, 4]))

        if self.box_loss_type in ['giou', 'ciou']:
            # IoU-style losses compare boxes in coordinate space, so decode the
            # selected deltas against their proposal boxes, then clip them.
            rpn_box_rois = tf.reshape(rpn_box_rois, [-1, 4])
            box_outputs = box_utils.decode_boxes(encoded_boxes=box_outputs,
                                                 anchors=rpn_box_rois,
                                                 weights=self.bbox_reg_weights)
            box_outputs = box_utils.clip_boxes(box_outputs, self.image_size)
        box_outputs = tf.reshape(box_outputs, [batch_size, -1, 4])

        if self.use_carl_loss:
            # class_agnostic_box is guaranteed False here (checked above), so
            # mask_targets is identical to class_targets.
            box_loss = self._fast_rcnn_box_carl_loss(
                box_outputs=box_outputs,
                box_targets=box_targets,
                class_targets=mask_targets,
                class_outputs=class_outputs,
                loss_type=self.box_loss_type,
                normalizer=2.0)
        else:
            box_loss = self._fast_rcnn_box_loss(
                box_outputs=box_outputs,
                box_targets=box_targets,
                class_targets=mask_targets,
                loss_type=self.box_loss_type,
                normalizer=1.0)
        box_loss *= self.fast_rcnn_box_loss_weight

        # The class loss takes one-hot targets over all self.num_classes
        # classes. It always uses the full class ids, even when box regression
        # is class-agnostic.
        class_loss = self._fast_rcnn_class_loss(
            class_outputs=class_outputs,
            class_targets_one_hot=tf.one_hot(class_targets, self.num_classes),
            normalizer=1.0)

        total_loss = class_loss + box_loss
        return total_loss, class_loss, box_loss