easy_rec/python/loss/f1_reweight_loss.py (31 lines of code) (raw):
# coding=utf-8
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
if tf.__version__ >= '2.0':
tf = tf.compat.v1
def f1_reweight_sigmoid_cross_entropy(labels,
                                      logits,
                                      beta_square,
                                      label_smoothing=0,
                                      weights=None):
    """Sigmoid cross entropy whose negatives are re-weighted for F-measure.

    Refer paper: Adaptive Scaling for Sparse Detection in Information
    Extraction.

    Samples whose label equals 1 keep weight 1. Every other sample in a
    column gets the batch-level weight

        neg_weight = TP / (beta_square * P + N - TN + 1e-8)

    where, per column and reduced over axis 0 (the batch axis):
      * P  is the sum of the labels (number of positives),
      * N  is ``batch_size - P`` (number of negatives),
      * TP is the sum of ``sigmoid(logits)`` over the positive samples,
      * TN is the sum of ``1 - sigmoid(logits)`` over the negative samples.

    Args:
      labels: 0/1 targets. A rank-1 tensor is expanded with a trailing axis
        of size 1; values are cast to float32.
      logits: raw scores matching `labels`. A rank-1 tensor is expanded
        with a trailing axis of size 1.
      beta_square: the squared beta of the F-beta measure used in the
        weight denominator (1 corresponds to F1).
      label_smoothing: forwarded unchanged to
        `tf.losses.sigmoid_cross_entropy`.
      weights: optional per-sample weights. They are cast to float32, a
        rank-1 tensor is expanded with a trailing axis of size 1, and they
        are multiplied into the computed weights.

    Returns:
      The value returned by `tf.losses.sigmoid_cross_entropy` for the given
      labels, logits and combined weights.
    """
    # Bring both inputs to rank 2 first so that every statistic below is
    # computed per column with a consistent shape.
    if len(logits.shape.as_list()) == 1:
        logits = tf.expand_dims(logits, -1)
    if len(labels.shape.as_list()) == 1:
        labels = tf.expand_dims(labels, -1)
    labels = tf.cast(labels, tf.float32)
    probs = tf.nn.sigmoid(logits)

    batch_size = tf.shape(labels)[0]
    batch_size_float = tf.cast(batch_size, tf.float32)
    num_pos = tf.reduce_sum(labels, axis=0)
    num_neg = batch_size_float - num_pos

    # Expected true positives / true negatives must be restricted to the
    # samples of the matching class. Summing the probabilities over the
    # whole batch would make the weight collapse to ~1 when beta_square == 1
    # and could turn the denominator negative when beta_square < 1.
    tp = tf.reduce_sum(probs * labels, axis=0)
    tn = tf.reduce_sum((1.0 - probs) * (1.0 - labels), axis=0)

    # `num_neg - tn` is the expected number of false positives; the epsilon
    # keeps the division defined when the denominator would otherwise be 0.
    neg_weight = tp / (beta_square * num_pos + num_neg - tn + 1e-8)
    neg_weight_tile = tf.tile(tf.expand_dims(neg_weight, 0), [batch_size, 1])

    # Positives keep weight 1, everything else uses the adaptive weight.
    final_weights = tf.where(
        tf.equal(labels, 1.0), tf.ones_like(labels), neg_weight_tile)

    if weights is not None:
        weights = tf.cast(weights, tf.float32)
        if len(weights.shape.as_list()) == 1:
            weights = tf.expand_dims(weights, -1)
        final_weights *= weights

    return tf.losses.sigmoid_cross_entropy(
        labels, logits, final_weights, label_smoothing=label_smoothing)