in models/context_model.py [0:0]
def forward(self, context: th.Tensor, audio: th.Tensor = None):
    """
    Project a per-head context sequence, optionally adding an audio term,
    while keeping the time axis causal for the historic branch.

    :param context: B x T x heads x ch_in Tensor
    :param audio: optional B x T x audio_dim Tensor
    :return: B x T x heads x ch_out Tensor
    """
    batch, frames = context.shape[0], context.shape[1]
    # fold heads into the channel axis: B x (heads * ch_in) x T
    flat = context.view(batch, frames, -1).permute(0, 2, 1).contiguous()
    # current context frame, with the weight masked along the head axis
    out = F.conv1d(flat, self.masked_linear.weight * self.mask, bias=self.masked_linear.bias)
    # current audio frame, projected without any masking
    if audio is not None:
        out = out + self.unmasked_linear(audio.permute(0, 2, 1).contiguous())
    # strictly-past frames: drop the last frame and left-pad one extra step
    # so the dilated conv sees only earlier timesteps yet keeps length T
    if self.kernel_size > 0:
        pad_left = self.dilation * (self.kernel_size - 1) + 1
        out = out + self.historic(F.pad(flat[:, :, :-1], [pad_left, 0]))
    # unfold channels back into heads: B x T x heads x ch_out
    return out.permute(0, 2, 1).contiguous().view(batch, frames, self.heads, self.ch_out)