# tensorflow_addons/activations/sparsemax.py
def _compute_2d_sparsemax(logits):
"""Performs the sparsemax operation when axis=-1."""
shape_op = tf.shape(logits)
obs = tf.math.reduce_prod(shape_op[:-1])
dims = shape_op[-1]
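# Illustrative shape example: for logits of shape [2, 3, 5], obs = 2 * 3 = 6
# rows and dims = 5 columns after the reshape below.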
# In the paper, they call the logits z.
# The mean(logits) can be subtracted from the logits to make the algorithm
# more numerically stable. The instability in this algorithm comes mostly
# from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
# to zero. However, in practice the numerical instability issues are very
# minor and subtracting the mean causes extra issues with inf and nan
# input.
# Reshape to [obs, dims] as it is almost free and means the remaining
# code doesn't need to worry about the rank.
z = tf.reshape(logits, [obs, dims])
# Sort z in descending order (top_k with k=dims is a full descending sort).
z_sorted, _ = tf.nn.top_k(z, k=dims)
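# Illustrative running example (reused in the comments below): an unsorted
# row z = [0.6, 0.8, -1.0] becomes z_sorted = [0.8, 0.6, -1.0].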
# calculate k(z)
z_cumsum = tf.math.cumsum(z_sorted, axis=-1)
k = tf.range(1, tf.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
z_check = 1 + k * z_sorted > z_cumsum
# Because the z_check vector is always of the form [1,1,...,1,0,0,...,0],
# finding the (index + 1) of the last `1` is the same as summing the 1s.
k_z = tf.math.reduce_sum(tf.cast(z_check, tf.int32), axis=-1)
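# Running example: for z_sorted = [0.8, 0.6, -1.0],
#   z_cumsum         = [0.8, 1.4, 0.4]
#   1 + k * z_sorted = [1.8, 2.2, -2.0]
#   z_check          = [True, True, False]  =>  k_z = 2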
# calculate tau(z)
# If there are inf values, or all values are -inf, then k_z will be zero.
# This is mathematically invalid and would also cause the gather_nd to fail.
# Prevent this issue for now by setting k_z = 1 if k_z = 0; this is then
# fixed later (see p_safe) by returning p = nan. This results in the same
# behavior as softmax.
k_z_safe = tf.math.maximum(k_z, 1)
indices = tf.stack([tf.range(0, obs), tf.reshape(k_z_safe, [-1]) - 1], axis=1)
tau_sum = tf.gather_nd(z_cumsum, indices)
tau_z = (tau_sum - 1) / tf.cast(k_z, logits.dtype)
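# Running example continued: tau_sum = z_cumsum[k_z - 1] = 1.4, so
# tau_z = (1.4 - 1) / 2 = 0.2. Note the division uses k_z, not k_z_safe:
# when k_z = 0 it produces inf/nan, but those rows are overwritten with
# nan in p_safe below, so the invalid intermediate never escapes.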
# calculate p
p = tf.math.maximum(tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
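# Running example continued (p follows the unsorted order of z):
# p = max(0, [0.6, 0.8, -1.0] - 0.2) = [0.4, 0.6, 0.0], which sums to 1
# with the smallest logit clipped to exactly zero -- the defining
# sparsemax property.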
# If k_z = 0 or z contains nan, the input is invalid and p is replaced by nan
p_safe = tf.where(
tf.expand_dims(
tf.math.logical_or(tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])),
axis=-1,
),
tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)),
p,
)
# Reshape back to original size
p_safe = tf.reshape(p_safe, shape_op)
return p_safe
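# A minimal usage sketch (illustrative, not part of the library source). It
# assumes `import tensorflow as tf` is in scope at module level, as in the
# full sparsemax.py, and reuses the running example from the comments above.
if __name__ == "__main__":
    example_logits = tf.constant([[0.6, 0.8, -1.0]], dtype=tf.float32)
    # Expected per the running example: [[0.4, 0.6, 0.0]]
    print(_compute_2d_sparsemax(example_logits).numpy())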