def sequence_sampled_softmax_cross_entropy()

in ludwig/modules/loss_modules.py
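Computes the loss of a sequence output feature and returns a (train_loss, eval_loss) pair of per-example losses. Training uses a sampled softmax, which approximates the full softmax by scoring each target class against only loss['negative_samples'] sampled negative classes; evaluation uses the exact softmax cross-entropy over the full vocabulary.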


import tensorflow as tf
import tensorflow_addons as tfa

# sample_values_from_classes is a helper defined earlier in this module.


def sequence_sampled_softmax_cross_entropy(targets, targets_sequence_length,
                                           eval_logits, train_logits,
                                           class_weights,
                                           class_biases, loss,
                                           num_classes):
    """Sequence loss: sampled softmax for training, full softmax for eval."""
    # Pad both sets of logits along the time axis so they match the targets'
    # maximum sequence length in this batch (this assumes the logits' time
    # dimension is never longer than the targets').
    batch_max_targets_sequence_length = tf.shape(targets)[1]

    batch_max_train_logits_sequence_length = tf.shape(train_logits)[1]
    difference_train = (batch_max_targets_sequence_length -
                        batch_max_train_logits_sequence_length)
    padded_train_logits = tf.pad(train_logits,
                                 [[0, 0], [0, difference_train], [0, 0]])

    batch_max_eval_logits_sequence_length = tf.shape(eval_logits)[1]
    difference_eval = (batch_max_targets_sequence_length -
                       batch_max_eval_logits_sequence_length)
    padded_eval_logits = tf.pad(eval_logits,
                                [[0, 0], [0, difference_eval], [0, 0]])

    # Flatten the targets to shape [batch_size * time, 1] and draw one shared
    # set of negative samples for the whole batch, using the candidate
    # sampler configured in the loss dict.
    output_exp = tf.cast(tf.reshape(targets, [-1, 1]), tf.int64)
    sampled_values = sample_values_from_classes(output_exp, loss['sampler'],
                                                num_classes,
                                                loss['negative_samples'],
                                                loss['unique'],
                                                loss['class_counts'],
                                                loss['distortion'])

    def _sampled_loss(labels, logits):
        """Sampled softmax loss over one flattened batch of timesteps.

        `logits` here are the decoder's hidden activations, not class
        scores: tf.nn.sampled_softmax_loss projects them with the output
        weights internally. The weights are transposed because the op
        expects a [num_classes, dim] matrix.
        """
        labels = tf.cast(labels, tf.int64)
        labels = tf.reshape(labels, [-1, 1])
        logits = tf.cast(logits, tf.float32)

        return tf.cast(
            tf.nn.sampled_softmax_loss(weights=tf.transpose(class_weights),
                                       biases=class_biases,
                                       labels=labels,
                                       inputs=logits,
                                       num_sampled=loss['negative_samples'],
                                       num_classes=num_classes,
                                       sampled_values=sampled_values),
            tf.float32)

    # Training loss: sampled softmax via the custom loss function above,
    # masked by the true sequence lengths, averaged across timesteps but
    # kept per example in the batch.
    train_loss = tfa.seq2seq.sequence_loss(
        padded_train_logits,
        targets,
        tf.sequence_mask(targets_sequence_length,
                         batch_max_targets_sequence_length, dtype=tf.float32),
        average_across_timesteps=True,
        average_across_batch=False,
        softmax_loss_function=_sampled_loss
    )

    # Evaluation loss: exact softmax cross-entropy over the full vocabulary
    # (no sampling), with the same sequence mask.
    eval_loss = tfa.seq2seq.sequence_loss(
        padded_eval_logits,
        targets,
        tf.sequence_mask(targets_sequence_length,
                         batch_max_targets_sequence_length, dtype=tf.float32),
        average_across_timesteps=True,
        average_across_batch=False
    )

    # Both losses have shape [batch_size].
    return train_loss, eval_loss
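
For context, below is a minimal sketch of how the function might be called, assuming TensorFlow 2.x eager execution and that this module's helpers are importable. The shapes, the hidden size, the 'log_uniform' sampler name, and every value in the loss dict are illustrative assumptions, not values taken from Ludwig.

# A hedged usage sketch: every shape and configuration value here is an
# illustrative assumption.
import tensorflow as tf

batch_size, max_seq_len, hidden_size, num_classes = 4, 10, 256, 5000

# Integer class targets and their true lengths.
targets = tf.random.uniform([batch_size, max_seq_len],
                            maxval=num_classes, dtype=tf.int32)
targets_sequence_length = tf.fill([batch_size], max_seq_len)

# train_logits are decoder hidden states ([batch, time, hidden]), since the
# sampled softmax projects them itself; eval_logits are full-vocabulary
# scores ([batch, time, num_classes]).
train_logits = tf.random.normal([batch_size, max_seq_len, hidden_size])
eval_logits = tf.random.normal([batch_size, max_seq_len, num_classes])

# Output projection stored as [hidden, num_classes]; the function transposes
# it before handing it to tf.nn.sampled_softmax_loss.
class_weights = tf.random.normal([hidden_size, num_classes])
class_biases = tf.zeros([num_classes])

loss = {
    'sampler': 'log_uniform',  # assumed sampler name
    'negative_samples': 25,
    'unique': True,
    'class_counts': None,      # only relevant to unigram-style samplers
    'distortion': 1.0,
}

train_loss, eval_loss = sequence_sampled_softmax_cross_entropy(
    targets, targets_sequence_length, eval_logits, train_logits,
    class_weights, class_biases, loss, num_classes)
# Each loss has shape [batch_size].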