in ludwig/modules/loss_modules.py [0:0]
def sequence_sampled_softmax_cross_entropy(targets, targets_sequence_length,
                                           eval_logits, train_logits,
                                           class_weights,
                                           class_biases, loss,
                                           num_classes):
    """Compute per-example sequence losses for training and evaluation.

    Training uses sampled softmax (cheap approximation over
    ``loss['negative_samples']`` negatives), while evaluation uses the full
    softmax over all ``num_classes`` classes, so the two returned losses are
    NOT directly comparable in magnitude.

    Args:
        targets: int tensor of target class ids; assumed shape
            ``(batch, max_targets_len)`` — TODO confirm with callers.
        targets_sequence_length: int tensor ``(batch,)`` of true (unpadded)
            target lengths, used to mask padding timesteps out of the loss.
        eval_logits: float tensor ``(batch, eval_len, hidden)`` of decoder
            outputs used for evaluation.
        train_logits: float tensor ``(batch, train_len, hidden)`` of decoder
            outputs used for training.
        class_weights: projection matrix; transposed below before being
            passed to ``tf.nn.sampled_softmax_loss``, which expects
            ``(num_classes, hidden)``.
        class_biases: per-class bias vector ``(num_classes,)``.
        loss: dict of loss hyperparameters; keys read here: ``'sampler'``,
            ``'negative_samples'``, ``'unique'``, ``'class_counts'``,
            ``'distortion'``.
        num_classes: total number of output classes.

    Returns:
        Tuple ``(train_loss, eval_loss)`` of per-example losses, each of
        shape ``(batch,)`` (averaged across timesteps, not across batch).

    NOTE(review): the paddings below assume the targets dimension is at
    least as long as each logits dimension; a negative difference would make
    ``tf.pad`` raise — confirm upstream guarantees this.
    """
    batch_max_targets_sequence_length = tf.shape(targets)[1]

    # Right-pad both logits tensors along the time axis so they line up
    # with the (possibly longer) targets tensor.
    batch_max_train_logits_sequence_length = tf.shape(train_logits)[1]
    difference_train = (batch_max_targets_sequence_length -
                        batch_max_train_logits_sequence_length)
    padded_train_logits = tf.pad(train_logits,
                                 [[0, 0], [0, difference_train], [0, 0]])

    batch_max_eval_logits_sequence_length = tf.shape(eval_logits)[1]
    difference_eval = (batch_max_targets_sequence_length -
                       batch_max_eval_logits_sequence_length)
    padded_eval_logits = tf.pad(eval_logits,
                                [[0, 0], [0, difference_eval], [0, 0]])

    # Flatten targets to (batch * time, 1) int64 — the shape the candidate
    # samplers expect as true classes.
    output_exp = tf.cast(tf.reshape(targets, [-1, 1]), tf.int64)
    sampled_values = sample_values_from_classes(output_exp, loss['sampler'],
                                                num_classes,
                                                loss['negative_samples'],
                                                loss['unique'],
                                                loss['class_counts'],
                                                loss['distortion'])

    def _sampled_loss(labels, logits):
        """Per-timestep sampled-softmax loss plugged into sequence_loss."""
        labels = tf.cast(labels, tf.int64)
        labels = tf.reshape(labels, [-1, 1])
        logits = tf.cast(logits, tf.float32)
        return tf.cast(
            # sampled_softmax_loss wants weights as (num_classes, hidden),
            # hence the transpose of the projection matrix.
            tf.nn.sampled_softmax_loss(weights=tf.transpose(class_weights),
                                       biases=class_biases,
                                       labels=labels,
                                       inputs=logits,
                                       num_sampled=loss['negative_samples'],
                                       num_classes=num_classes,
                                       sampled_values=sampled_values),
            tf.float32)

    # Single mask shared by both losses: 1.0 on real target timesteps,
    # 0.0 on padding (previously this expression was duplicated).
    targets_mask = tf.sequence_mask(targets_sequence_length,
                                    batch_max_targets_sequence_length,
                                    dtype=tf.float32)

    train_loss = tfa.seq2seq.sequence_loss(
        padded_train_logits,
        targets,
        targets_mask,
        average_across_timesteps=True,
        average_across_batch=False,
        softmax_loss_function=_sampled_loss
    )

    # Eval uses the default (full) softmax, so its values are exact while
    # the train loss is a sampled approximation.
    eval_loss = tfa.seq2seq.sequence_loss(
        padded_eval_logits,
        targets,
        targets_mask,
        average_across_timesteps=True,
        average_across_batch=False
    )

    return train_loss, eval_loss