in models/context_model.py [0:0]
def _forward_inference(self, t: int, h: int, context: th.Tensor, audio: th.Tensor) -> dict:
    """
    Run one incremental inference step of the context model.

    The context is passed through ``self.embedding``, then through every layer in
    ``self.context_layers`` (each followed by a leaky ReLU with negative slope 0.2),
    then through ``self.logits``. Every sub-module is called through its
    ``forward_inference(t, h, ...)`` method.

    :param t: current time step
    :param h: current head
    :param context: B x T x heads x classes Tensor
    :param audio: B x T x audio_dim Tensor
    :return: dict with keys
        "logprobs": B x T x heads x classes Tensor containing log probabilities for each class
        "probs": B x T x heads x classes Tensor containing probabilities for each class
        "labels": B x T x heads LongTensor containing discretized class labels
    """
    x = self.embedding.forward_inference(t, h, context)
    # NOTE(review): source indentation was lost; the activation is applied after
    # every context layer here (the conventional reading) -- confirm.
    for layer in self.context_layers:
        x = layer.forward_inference(t, h, x, audio)
        x = F.leaky_relu(x, 0.2)
    logits = self.logits.forward_inference(t, h, x)
    logprobs = F.log_softmax(logits, dim=-1)
    # logprobs are already normalized over the class axis, so exponentiating them
    # gives the probabilities directly; a second softmax would only re-normalize.
    probs = logprobs.exp()
    # log is monotone, so the argmax over logprobs equals the argmax over probs.
    labels = th.argmax(logprobs, dim=-1)
    return {"logprobs": logprobs, "probs": probs, "labels": labels}