def forward()

in models/context_model.py [0:0]


    def forward(self, expression_one_hot: th.Tensor, audio_code: th.Tensor):
        """
        Predict a categorical distribution per head from previous labels and audio.

        :param expression_one_hot: B x T x heads x classes Tensor containing the
               one hot representation of previous labels
        :param audio_code: B x T x audio_dim Tensor containing the audio embedding
        :return: dict with the following entries
                 "logprobs": B x T x heads x C Tensor containing log probabilities for each class
                 "probs":    B x T x heads x C Tensor containing probabilities for each class
                 "labels":   B x T x heads LongTensor containing label indices
        """
        hidden = self.embedding(expression_one_hot)

        # Every context layer is conditioned on the audio code and is followed
        # by a leaky ReLU with negative slope 0.2.
        for context_layer in self.context_layers:
            hidden = F.leaky_relu(context_layer(hidden, audio_code), 0.2)

        # Normalise the per-class scores over the last (class) dimension.
        logprobs = F.log_softmax(self.logits(hidden), dim=-1)

        # Most likely class index for each head.
        labels = logprobs.argmax(dim=-1)

        # Softmax is invariant to the per-row shift applied by log_softmax, so
        # this yields the same distribution as a softmax over the raw scores.
        probs = F.softmax(logprobs, dim=-1)

        return {"logprobs": logprobs, "probs": probs, "labels": labels}