in models/src/wavenet_vocoder/mixture.py [0:0]
import torch
import torch.nn.functional as F
from torch.distributions import Normal


def mix_gaussian_loss(y_hat, y, log_scale_min=-7.0, reduce=True):
"""Mixture of continuous gaussian distributions loss
Note that it is assumed that input is scaled to [-1, 1].
Args:
y_hat (Tensor): Predicted output (B x C x T)
y (Tensor): Target (B x T x 1).
log_scale_min (float): Log scale minimum value
reduce (bool): If True, the losses are averaged or summed for each
minibatch.
Returns
Tensor: loss
"""
    assert y_hat.dim() == 3
    C = y_hat.size(1)
    if C == 2:
        nr_mix = 1
    else:
        assert C % 3 == 0
        nr_mix = C // 3
    # (B x T x C)
    y_hat = y_hat.transpose(1, 2)

    # unpack parameters.
    if C == 2:
        # special case for C == 2, just for compatibility
        logit_probs = None
        means = y_hat[:, :, 0:1]
        log_scales = torch.clamp(y_hat[:, :, 1:2], min=log_scale_min)
    else:
        # (B, T, num_mixtures) x 3
        logit_probs = y_hat[:, :, :nr_mix]
        means = y_hat[:, :, nr_mix : 2 * nr_mix]
        log_scales = torch.clamp(
            y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min
        )
    # B x T x 1 -> B x T x num_mixtures
    y = y.expand_as(means)
    centered_y = y - means

    # evaluating the centered target under a zero-mean Normal is the same as
    # evaluating y under Normal(means, exp(log_scales))
    dist = Normal(loc=0.0, scale=torch.exp(log_scales))
    # do we need to add a trick to avoid log(0)? no: log_prob is evaluated
    # analytically, and the clamp on log_scales keeps the scale positive
    log_probs = dist.log_prob(centered_y)
    if nr_mix > 1:
        log_probs = log_probs + F.log_softmax(logit_probs, -1)
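    # for nr_mix > 1, each entry of log_probs now holds log(pi_k) + log N(y | mu_k, s_k);
    # log-sum-exp over the mixture axis below yields the mixture log-likelihood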
    if reduce:
        if nr_mix == 1:
            return -torch.sum(log_probs)
        else:
            return -torch.sum(log_sum_exp(log_probs))
    else:
        if nr_mix == 1:
            return -log_probs
        else:
            return -log_sum_exp(log_probs).unsqueeze(-1)
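

# ``log_sum_exp`` is used above but not shown in this excerpt; it is assumed
# to be a numerically stable log(sum(exp(.))) over the last (mixture)
# dimension. A minimal sketch of such a helper (``torch.logsumexp(x, dim=-1)``
# is an equivalent built-in):
def log_sum_exp(x):
    """Numerically stable log(sum(exp(x))) over the last dimension."""
    m, _ = torch.max(x, dim=-1, keepdim=True)
    return m.squeeze(-1) + torch.log(torch.sum(torch.exp(x - m), dim=-1))


# Minimal usage sketch (illustrative only, not part of the original module):
# two mixture components -> C = 3 * num_mixtures = 6 channels.
if __name__ == "__main__":
    B, T, num_mixtures = 2, 100, 2
    y_hat = torch.randn(B, 3 * num_mixtures, T)  # predictions, (B x C x T)
    y = torch.rand(B, T, 1) * 2 - 1  # targets scaled to [-1, 1]
    loss = mix_gaussian_loss(y_hat, y, reduce=True)
    print(loss.item())  # summed negative log-likelihood over the batch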