in easy_rec/python/builders/loss_builder.py [0:0]
def _prob_to_logits(prob, epsilon):
  """Convert a binary probability tensor to logits via log(p / (1 - p)).

  Clips to [epsilon, 1 - epsilon] first so the log is numerically safe.
  """
  prob = tf.clip_by_value(prob, epsilon, 1 - epsilon)
  return tf.log(prob / (1 - prob))


def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
  """Build knowledge distillation losses.

  Args:
    kds: list of knowledge distillation objects of type KD.
    prediction_dict: dict of predict_name to predict tensors.
    label_dict: ordered dict of label_name to label (soft teacher) tensors.
    feature_dict: dict of feature name to feature value.

  Returns:
    dict mapping each distillation loss name (kd.loss_name if set, else a
    generated 'kd_loss_<pred_name>_<soft_label_name>') to its weighted
    scalar loss tensor.
  """
  loss_dict = {}
  # Loop-invariant numerical-stability constant; fetch once, not per kd.
  epsilon = tf.keras.backend.epsilon()
  for kd in kds:
    assert kd.pred_name in prediction_dict, \
        'invalid predict_name: %s available ones: %s' % (
            kd.pred_name, ','.join(prediction_dict.keys()))
    # Validate the soft label the same way as the prediction, so a typo in
    # the config fails with an informative message instead of a bare KeyError.
    assert kd.soft_label_name in label_dict, \
        'invalid soft_label_name: %s available ones: %s' % (
            kd.soft_label_name, ','.join(label_dict.keys()))
    loss_name = kd.loss_name
    if not loss_name:
      # Derive a unique, scope-safe name from the two tensor names.
      loss_name = 'kd_loss_' + kd.pred_name.replace('/', '_')
      loss_name += '_' + kd.soft_label_name.replace('/', '_')
    loss_weight = kd.loss_weight
    if kd.HasField('task_space_indicator_name') and kd.HasField(
        'task_space_indicator_value'):
      # Per-sample re-weighting: samples whose indicator feature matches the
      # configured value are "in task space" and get in_task_space_weight,
      # the rest get out_task_space_weight.
      in_task_space = tf.to_float(
          tf.equal(feature_dict[kd.task_space_indicator_name],
                   kd.task_space_indicator_value))
      loss_weight = loss_weight * (
          kd.in_task_space_weight * in_task_space + kd.out_task_space_weight *
          (1 - in_task_space))
    label = label_dict[kd.soft_label_name]
    pred = prediction_dict[kd.pred_name]
    # NOTE(review): assumes a statically known last dim for multi-class
    # predictions; an unknown (None) dim would break the num_class checks.
    num_class = 1 if len(pred.get_shape()) < 2 else pred.get_shape()[-1]
    if kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
      # Bring both sides to logit space, apply temperature smoothing, then
      # turn the label back into a probability (the BCE call below consumes
      # pred as logits via from_logits=True).
      if not kd.label_is_logits:  # label is prob
        label = _prob_to_logits(label, epsilon)
      if not kd.pred_is_logits:
        pred = _prob_to_logits(pred, epsilon)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      label = tf.nn.sigmoid(label)  # convert to prob
    elif kd.loss_type == LossType.KL_DIVERGENCE_LOSS:
      # Convert both sides to logits (binary uses the logit function,
      # multi-class uses log probs shifted by the max for stability),
      # temperature-scale, then map back to probabilities for KLD.
      if not kd.label_is_logits:  # label is prob
        if num_class == 1:  # for binary classification
          label = _prob_to_logits(label, epsilon)
        else:
          label = tf.math.log(label + epsilon)
          label -= tf.reduce_max(label)
      if not kd.pred_is_logits:
        if num_class == 1:  # for binary classification
          pred = _prob_to_logits(pred, epsilon)
        else:
          pred = tf.math.log(pred + epsilon)
          pred -= tf.reduce_max(pred)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      if num_class > 1:
        label = tf.nn.softmax(label)
        pred = tf.nn.softmax(pred)
      else:
        label = tf.nn.sigmoid(label)  # convert to prob
        pred = tf.nn.sigmoid(pred)  # convert to prob
    elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
      # Log-space temperature scaling, then back to probabilities, since
      # tf.losses.log_loss below expects probabilities on both sides.
      if not kd.label_is_logits:
        label = tf.math.log(label + epsilon)
      if not kd.pred_is_logits:
        pred = tf.math.log(pred + epsilon)
      if kd.temperature > 0:
        label = label / kd.temperature
        pred = pred / kd.temperature
      if num_class > 1:
        label = tf.nn.softmax(label)
        pred = tf.nn.softmax(pred)
      elif num_class == 1:
        label = tf.nn.sigmoid(label)
        pred = tf.nn.sigmoid(pred)
    if kd.loss_type == LossType.KL_DIVERGENCE_LOSS:
      if num_class == 1:
        # KLD needs a full distribution: expand binary prob p to [1-p, p].
        label = tf.expand_dims(label, 1)  # [B, 1]
        labels = tf.concat([1 - label, label], axis=1)  # [B, 2]
        pred = tf.expand_dims(pred, 1)  # [B, 1]
        preds = tf.concat([1 - pred, pred], axis=1)  # [B, 2]
      else:
        labels = label
        preds = pred
      losses = tf.keras.losses.KLD(labels, preds)
      loss_dict[loss_name] = tf.reduce_mean(
          losses, name=loss_name) * loss_weight
    elif kd.loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
      # pred is still logits here; label was converted back to prob above.
      losses = tf.keras.backend.binary_crossentropy(
          label, pred, from_logits=True)
      loss_dict[loss_name] = tf.reduce_mean(
          losses, name=loss_name) * loss_weight
    elif kd.loss_type == LossType.CROSS_ENTROPY_LOSS:
      loss_dict[loss_name] = tf.losses.log_loss(
          label, pred, weights=loss_weight)
    elif kd.loss_type == LossType.L2_LOSS:
      loss_dict[loss_name] = tf.losses.mean_squared_error(
          labels=label, predictions=pred, weights=loss_weight)
    else:
      # Delegate any other loss type to the generic loss builder, forwarding
      # the kd-specific loss_param oneof (and session ids when configured).
      loss_param = kd.WhichOneof('loss_param')
      kwargs = {}
      if loss_param is not None:
        loss_param = getattr(kd, loss_param)
        if hasattr(loss_param, 'session_name'):
          kwargs['session_ids'] = feature_dict[loss_param.session_name]
      loss_dict[loss_name] = build(
          kd.loss_type,
          label,
          pred,
          loss_weight=loss_weight,
          loss_param=loss_param,
          **kwargs)
  return loss_dict