in easy_rec/python/builders/loss_builder.py [0:0]
def build(loss_type,
          label,
          pred,
          loss_weight=1.0,
          num_class=1,
          loss_param=None,
          **kwargs):
  """Build the loss tensor selected by `loss_type`.

  Args:
    loss_type: a `LossType` enum value choosing which loss to build.
    label: label tensor. For CLASSIFICATION with `num_class != 1` it must be
      int32 or int64 (sparse class ids for sparse softmax cross entropy).
    pred: model output. Passed as logits to CLASSIFICATION and
      BINARY_CROSS_ENTROPY_LOSS (`from_logits=True`), and as `predictions`
      to CROSS_ENTROPY_LOSS (`tf.losses.log_loss`) and the L2 losses.
    loss_weight: scalar or tensor weights, forwarded as `weights` /
      `sample_weights` to the underlying loss. ZILN_LOSS only honours a
      scalar value (see the note in that branch).
    num_class: only used by CLASSIFICATION: 1 selects sigmoid cross entropy,
      any other value selects sparse softmax cross entropy.
    loss_param: optional loss-specific config message. Each branch reads its
      own fields from it; when None, hard-coded defaults are used.
    **kwargs: `loss_name` (default 'unknown') is removed and passed as `name`
      to the loss functions that take it. `session_ids` is read by the
      ranking losses. All remaining kwargs (including `session_ids` if
      present) are forwarded to the `tf.losses.*` calls of CLASSIFICATION,
      CROSS_ENTROPY_LOSS, L2_LOSS and SIGMOID_L2_LOSS.

  Returns:
    The loss tensor produced by the selected loss function.

  Raises:
    AssertionError: CLASSIFICATION with `num_class != 1` and a non-integer
      label dtype.
    ValueError: `loss_type` is not supported.
  """
  loss_name = kwargs.pop('loss_name', 'unknown')
  # Only read (not popped), so it is still forwarded with **kwargs exactly as
  # before; used by the pairwise/listwise/JRC ranking losses below.
  session = kwargs.get('session_ids')

  if loss_type == LossType.CLASSIFICATION:
    if num_class == 1:
      return tf.losses.sigmoid_cross_entropy(
          label, logits=pred, weights=loss_weight, **kwargs)
    else:
      assert label.dtype in [tf.int32, tf.int64], \
          'label.dtype must in [tf.int32, tf.int64] when use sparse_softmax_cross_entropy.'
      return tf.losses.sparse_softmax_cross_entropy(
          labels=label, logits=pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.CROSS_ENTROPY_LOSS:
    return tf.losses.log_loss(label, pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.BINARY_CROSS_ENTROPY_LOSS:
    losses = tf.keras.backend.binary_crossentropy(label, pred, from_logits=True)
    # Apply loss_weight with the same SUM_BY_NONZERO_WEIGHTS reduction the
    # tf.losses.* branches use, so sample weights are not silently dropped.
    # With the default loss_weight=1.0 this is the plain mean of `losses`.
    # loss_collection=None: do not register the result in GraphKeys.LOSSES.
    return tf.losses.compute_weighted_loss(
        losses, weights=loss_weight, loss_collection=None)
  elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
    logging.info('%s is used', LossType.Name(loss_type))
    return tf.losses.mean_squared_error(
        labels=label, predictions=pred, weights=loss_weight, **kwargs)
  elif loss_type == LossType.ZILN_LOSS:
    loss = zero_inflated_lognormal_loss(label, pred)
    # NOTE(review): only a scalar loss_weight != 1.0 is applied here; a
    # tensor-valued (per-sample) loss_weight is ignored — confirm intent.
    if np.isscalar(loss_weight) and loss_weight != 1.0:
      return loss * loss_weight
    return loss
  elif loss_type == LossType.JRC_LOSS:
    if loss_param is None:
      return jrc_loss(label, pred, session, name=loss_name)
    return jrc_loss(
        label,
        pred,
        session,
        loss_param.alpha,
        loss_weight_strategy=loss_param.loss_weight_strategy,
        sample_weights=loss_weight,
        same_label_loss=loss_param.same_label_loss,
        name=loss_name)
  elif loss_type == LossType.PAIR_WISE_LOSS:
    margin = 0 if loss_param is None else loss_param.margin
    temp = 1.0 if loss_param is None else loss_param.temperature
    return pairwise_loss(
        label,
        pred,
        session_ids=session,
        margin=margin,
        temperature=temp,
        weights=loss_weight,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_LOGISTIC_LOSS:
    temp = 1.0 if loss_param is None else loss_param.temperature
    ohem_ratio = 1.0 if loss_param is None else loss_param.ohem_ratio
    # hinge_margin is optional in the config: None unless explicitly set.
    hinge_margin = None
    if loss_param is not None and loss_param.HasField('hinge_margin'):
      hinge_margin = loss_param.hinge_margin
    lbl_margin = False if loss_param is None else loss_param.use_label_margin
    return pairwise_logistic_loss(
        label,
        pred,
        session_ids=session,
        temperature=temp,
        hinge_margin=hinge_margin,
        ohem_ratio=ohem_ratio,
        weights=loss_weight,
        use_label_margin=lbl_margin,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_HINGE_LOSS:
    temp, ohem_ratio, margin = 1.0, 1.0, 1.0
    label_is_logits, use_label_margin, use_exponent = True, True, False
    if loss_param is not None:
      temp = loss_param.temperature
      ohem_ratio = loss_param.ohem_ratio
      margin = loss_param.margin
      label_is_logits = loss_param.label_is_logits
      use_label_margin = loss_param.use_label_margin
      use_exponent = loss_param.use_exponent
    return pairwise_hinge_loss(
        label,
        pred,
        session_ids=session,
        temperature=temp,
        margin=margin,
        ohem_ratio=ohem_ratio,
        weights=loss_weight,
        label_is_logits=label_is_logits,
        use_label_margin=use_label_margin,
        use_exponent=use_exponent,
        name=loss_name)
  elif loss_type == LossType.PAIRWISE_FOCAL_LOSS:
    if loss_param is None:
      return pairwise_focal_loss(
          label, pred, session_ids=session, weights=loss_weight, name=loss_name)
    hinge_margin = None
    if loss_param.HasField('hinge_margin'):
      hinge_margin = loss_param.hinge_margin
    return pairwise_focal_loss(
        label,
        pred,
        session_ids=session,
        gamma=loss_param.gamma,
        alpha=loss_param.alpha if loss_param.HasField('alpha') else None,
        hinge_margin=hinge_margin,
        ohem_ratio=loss_param.ohem_ratio,
        temperature=loss_param.temperature,
        weights=loss_weight,
        name=loss_name)
  elif loss_type == LossType.LISTWISE_RANK_LOSS:
    trans_fn, temp, label_is_logits, scale = None, 1.0, False, False
    if loss_param is not None:
      temp = loss_param.temperature
      label_is_logits = loss_param.label_is_logits
      scale = loss_param.scale_logits
      if loss_param.HasField('transform_fn'):
        trans_fn = loss_param.transform_fn
    # Unlike the pairwise losses, loss_name is not passed here.
    return listwise_rank_loss(
        label,
        pred,
        session,
        temperature=temp,
        label_is_logits=label_is_logits,
        transform_fn=trans_fn,
        scale_logits=scale,
        weights=loss_weight)
  elif loss_type == LossType.LISTWISE_DISTILL_LOSS:
    trans_fn, temp, label_clip_max_value, scale = None, 1.0, 512.0, False
    if loss_param is not None:
      temp = loss_param.temperature
      label_clip_max_value = loss_param.label_clip_max_value
      scale = loss_param.scale_logits
      if loss_param.HasField('transform_fn'):
        trans_fn = loss_param.transform_fn
    # Unlike the pairwise losses, loss_name is not passed here.
    return listwise_distill_loss(
        label,
        pred,
        session,
        temperature=temp,
        label_clip_max_value=label_clip_max_value,
        transform_fn=trans_fn,
        scale_logits=scale,
        weights=loss_weight)
  elif loss_type == LossType.F1_REWEIGHTED_LOSS:
    f1_beta_square = 1.0 if loss_param is None else loss_param.f1_beta_square
    label_smoothing = 0 if loss_param is None else loss_param.label_smoothing
    return f1_reweight_sigmoid_cross_entropy(
        label,
        pred,
        f1_beta_square,
        weights=loss_weight,
        label_smoothing=label_smoothing)
  elif loss_type == LossType.BINARY_FOCAL_LOSS:
    if loss_param is None:
      return sigmoid_focal_loss_with_logits(
          label, pred, sample_weights=loss_weight, name=loss_name)
    gamma = loss_param.gamma
    # alpha is optional in the config: None unless explicitly set.
    alpha = None
    if loss_param.HasField('alpha'):
      alpha = loss_param.alpha
    return sigmoid_focal_loss_with_logits(
        label,
        pred,
        gamma=gamma,
        alpha=alpha,
        ohem_ratio=loss_param.ohem_ratio,
        sample_weights=loss_weight,
        label_smoothing=loss_param.label_smoothing,
        name=loss_name)
  else:
    raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type))