def sample()

in models/context_model.py [0:0]


    def sample(self, audio_code: th.Tensor, argmax: bool = False):
        """
        Autoregressively sample a one-hot latent code conditioned on encoded audio.

        Frames are generated left to right and, within a frame, heads are generated in
        order. Each prediction sees the codes generated so far inside a sliding window of
        at most ``self.receptive_field() + 1`` frames ending at the current frame.

        :param audio_code: 1 x T x audio_dim Tensor containing the encoded audio for the sequence
            (only batch size 1 is supported)
        :param argmax: if False, draw each class with the Gumbel-max trick (argmax of the
            logprobs plus Gumbel noise, i.e. a sample from softmax(logprobs));
            if True, deterministically pick the class with the highest logprob
        :return: dict with key "one_hot" mapping to a 1 x T x heads x classes Tensor
            containing the one-hot representation of the sampled latent code
        """
        assert audio_code.shape[0] == 1
        T = audio_code.shape[1]
        one_hot = th.zeros(1, T, self.heads, self.classes, device=audio_code.device)
        # presumably clears any cached incremental-inference state — confirm in _reset
        self._reset()
        # loop-invariant: the window size does not change while sampling
        receptive_field = self.receptive_field()
        for t in range(T):
            start, end = max(0, t - receptive_field), t + 1
            # `context` is a view into `one_hot`, so the codes written below for earlier
            # heads of frame t are visible when predicting later heads of the same frame
            context = one_hot[:, start:end, :, :]
            audio = audio_code[:, start:end, :]
            for h in range(self.heads):
                # logprobs of head h at the newest frame of the window
                logprobs = self._forward_inference(t, h, context, audio)["logprobs"][:, -1, h, :]
                if not argmax:
                    # Gumbel-max trick; the clamp keeps the inner log() away from log(0)
                    uniform = th.clamp(th.rand(logprobs.shape, device=logprobs.device), min=1e-10, max=1)
                    gumbel = -th.log(-th.log(uniform))
                    logprobs = logprobs + gumbel
                label_idx = th.argmax(logprobs, dim=-1).squeeze().item()
                one_hot[:, t, h, label_idx] = 1
        return {"one_hot": one_hot}