in models/src/wavenet_vocoder/mixture.py [0:0]
def sample_from_mix_gaussian(y, log_scale_min=-7.0):
    """Draw one sample per time step from a mixture of Gaussians.

    The channel axis of ``y`` carries the distribution parameters:

    * ``C == 2``: a single Gaussian laid out as ``[mean, log_scale]``.
    * ``C == 3 * nr_mix``: ``nr_mix`` mixture logits, followed by
      ``nr_mix`` means, followed by ``nr_mix`` log scales.

    Args:
        y (Tensor): Distribution parameters, B x C x T.
        log_scale_min (float): Lower bound applied to the selected log
            scales before exponentiation. This keeps the standard
            deviation strictly positive even when a very negative log
            scale would make ``exp`` underflow to 0.

    Returns:
        Tensor: Samples of shape B x T, clamped to the range [-1, 1].

    Raises:
        AssertionError: If ``C`` is neither 2 nor a positive multiple of 3.
    """
    C = y.size(1)
    if C == 2:
        nr_mix = 1
    else:
        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept as
        # AssertionError so existing callers see the same error.
        if C == 0 or C % 3 != 0:
            raise AssertionError(
                f"expected 2 or a positive multiple of 3 channels, got {C}"
            )
        nr_mix = C // 3

    # B x C x T -> B x T x C, so the parameters sit on the last axis.
    y = y.transpose(1, 2)

    if nr_mix > 1:
        logit_probs = y[:, :, :nr_mix]

        # Sample the mixture indicator from the softmax over the logits with
        # the Gumbel-max trick: argmax(logits - log(-log(u))), u ~ U(0, 1).
        # The uniform range is shrunk by 1e-5 on both sides to keep the
        # double log finite.
        temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
        temp = logit_probs.data - torch.log(-torch.log(temp))
        _, argmax = temp.max(dim=-1)

        # (B, T) -> (B, T, nr_mix); ``to_one_hot`` is defined elsewhere in
        # this module.
        one_hot = to_one_hot(argmax, nr_mix)

        # Keep only the parameters of the sampled component.
        means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
        log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1)
    elif C == 2:
        means, log_scales = y[:, :, 0], y[:, :, 1]
    else:
        # C == 3: one component, so the logit in channel 0 is not needed.
        means, log_scales = y[:, :, 1], y[:, :, 2]

    # Apply the documented lower bound before exponentiating.
    log_scales = torch.clamp(log_scales, min=log_scale_min)
    scales = torch.exp(log_scales)

    dist = Normal(loc=means, scale=scales)
    x = dist.sample()

    return torch.clamp(x, min=-1.0, max=1.0)