in models/context_model.py [0:0]
def sample(self, audio_code: th.Tensor, argmax: bool = False):
    """
    Autoregressively generate a one-hot latent code conditioned on encoded audio.

    Frames are produced left to right and, within each frame, heads are produced in
    order 0..heads-1. Every step queries ``self._forward_inference`` with a window of
    at most ``self.receptive_field() + 1`` frames ending at the current frame, taken
    from both the partially filled one-hot code and ``audio_code``.

    :param audio_code: 1 x T x audio_dim Tensor containing the encoded audio for the sequence
        (only batch size 1 is supported)
    :param argmax: if False, draw each label from the categorical distribution
        softmax(logprobs) via the Gumbel-max trick; if True, take the class with the
        highest log-probability
    :return: dict with key "one_hot" mapping to a 1 x T x heads x classes Tensor (on the
        device of ``audio_code``) holding the one-hot representation of the latent code
    """
    # label_idx below is a single Python int per step, so only batch size 1 is supported.
    assert audio_code.shape[0] == 1, "sample() only supports batch size 1"
    T = audio_code.shape[1]
    one_hot = th.zeros(1, T, self.heads, self.classes, device=audio_code.device)
    # NOTE(review): presumably clears incremental-inference state before a new sequence.
    self._reset()
    # Loop-invariant: evaluate once rather than once per frame.
    receptive_field = self.receptive_field()
    # The result is a hard argmax and cannot be differentiated, so don't record autograd
    # graphs (this also avoids saving `context`, which is modified in place below).
    with th.no_grad():
        for t in range(T):
            start, end = max(0, t - receptive_field), t + 1
            # `context` is a view into `one_hot`: labels written for earlier heads of frame t
            # are visible when the model is queried for later heads of the same frame.
            context = one_hot[:, start:end, :, :]
            audio = audio_code[:, start:end, :]
            for h in range(self.heads):
                # log-probabilities of head h at the newest frame: 1 x classes
                logprobs = self._forward_inference(t, h, context, audio)["logprobs"][:, -1, h, :]
                if not argmax:
                    # Gumbel-max trick: argmax(logprobs + G), G ~ Gumbel(0, 1), is a sample
                    # from softmax(logprobs). Keep u strictly inside (0, 1): u == 0 breaks the
                    # inner log and u == 1 makes the outer log produce +inf noise.
                    u = th.rand(logprobs.shape, device=logprobs.device)
                    u = th.clamp(u, min=1e-10, max=1 - th.finfo(u.dtype).eps)
                    logprobs = logprobs - th.log(-th.log(u))
                label_idx = th.argmax(logprobs, dim=-1).item()
                one_hot[:, t, h, label_idx] = 1
    return {"one_hot": one_hot}