in models/context_model.py [0:0]
def forward(self, expression_one_hot: th.Tensor, audio_code: th.Tensor) -> dict:
    """
    Predict a per-head class distribution from the previous labels and the audio embedding.

    :param expression_one_hot: B x T x heads x classes Tensor containing one hot representation of previous labels
    :param audio_code: B x T x audio_dim Tensor containing the audio embedding
    :return: dict with the entries
        "logprobs": B x T x heads x C Tensor containing log probabilities for each class
        "probs": B x T x heads x C Tensor containing probabilities for each class
        "labels": B x T x heads LongTensor containing label indices (argmax over the class axis)
    """
    x = self.embedding(expression_one_hot)
    # Each context layer receives the audio code as conditioning input and is
    # followed by a leaky ReLU with negative slope 0.2.
    for layer in self.context_layers:
        x = F.leaky_relu(layer(x, audio_code), 0.2)
    logits = self.logits(x)
    logprobs = F.log_softmax(logits, dim=-1)
    # Normalize the logits directly rather than applying softmax on top of the
    # already-normalized log-probabilities; softmax is shift-invariant, so the
    # two are mathematically identical, but this form states the intent.
    probs = F.softmax(logits, dim=-1)
    # log_softmax is monotonic along the class axis, so this equals argmax(logits).
    labels = th.argmax(logprobs, dim=-1)
    return {"logprobs": logprobs, "probs": probs, "labels": labels}