in models/context_model.py [0:0]
import torch as th
import torch.nn.functional as F
from typing import Optional


def forward_inference(self, t: int, h: int, context: th.Tensor, audio: Optional[th.Tensor] = None):
"""
:param t: current time step
:param h: current head
:param context: B x T x heads x ch_in Tensor
:param audio: B x T x audio_dim Tensor
:return: B x T x heads x ch_out Tensor
"""
B, T = context.shape[0], context.shape[1]
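    # flatten heads into channels and move time last: B x (heads*ch_in) x T, as conv1d expects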
context = context.view(B, T, -1).permute(0, 2, 1).contiguous()
    if self.historic_t < t:
        # first call for this time step: append a fresh zero column to the output
        # buffer (note the hard-coded batch dimension of 1)
        self.buffer = self.buffer.to(context.device)
        self.buffer = th.cat(
            [self.buffer, th.zeros(1, self.buffer.shape[1], 1, device=self.buffer.device)],
            dim=-1)
    # contribution of already-predicted heads at this time step (masked, causal part)
    # start from the bias slice for head h; clone avoids in-place ops on a Parameter view
    y_masked = self.masked_linear.bias[h*self.ch_out:(h+1)*self.ch_out].view(1, -1, 1).clone()
    if h > 0:
        y_masked += F.conv1d(context[:, :h*self.ch_in, -1:],
                             self.masked_linear.weight[h*self.ch_out:(h+1)*self.ch_out, :h*self.ch_in, :])
    self.buffer[:, h*self.ch_out:(h+1)*self.ch_out, -1:] += y_masked
    # contribution of the audio conditioning at the current time step
    if audio is not None:
        audio = audio[:, -1:, :].permute(0, 2, 1).contiguous()  # B x audio_dim x 1
        y_audio = F.conv1d(audio,
                           self.unmasked_linear.weight[h*self.ch_out:(h+1)*self.ch_out, :, :],
                           bias=self.unmasked_linear.bias[h*self.ch_out:(h+1)*self.ch_out])
        self.buffer[:, h*self.ch_out:(h+1)*self.ch_out, -1:] += y_audio
    # contribution of historic time steps (computed once per time step, on the first head call)
    if self.kernel_size > 0 and self.historic_t < t:
        hist = context[:, :, -self.receptive_field():-1]
        if hist.shape[-1] < self.receptive_field() - 1:
            # zero-pad on the left until the full receptive field is available
            hist = F.pad(hist, pad=[self.receptive_field() - 1 - hist.shape[-1], 0])
        hist = self.historic(hist)
        self.buffer[:, :, -1:] += hist
    # mark this time step as processed so the buffer is extended only once per step,
    # even when kernel_size == 0 and the historic branch is skipped
    self.historic_t = t
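    # buffer is 1 x (heads*ch_out) x T'; reshape back to 1 x T' x heads x ch_out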
return self.buffer.permute(0, 2, 1).contiguous().view(1, -1, self.heads, self.ch_out)
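

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). It illustrates the calling
# convention implied by forward_inference: t is non-decreasing, heads are
# visited in order 0..heads-1, and each head's prediction is fed back into
# `context` before the next head is predicted. `sample_sketch`, the greedy
# one-hot feedback, and the assumption ch_in == ch_out (categorical logits
# per head) are illustrative only; the repository may ship its own sampler.
def sample_sketch(model, audio: th.Tensor) -> th.Tensor:
    """audio: 1 x T x audio_dim; returns 1 x T x heads x ch_out logits."""
    # assumes model.buffer and model.historic_t have been (re)initialized beforehand
    T = audio.shape[1]
    # one context frame per time step, starting from a single zero frame
    context = th.zeros(1, 1, model.heads, model.ch_in, device=audio.device)
    out = None
    for t in range(T):
        for h in range(model.heads):
            out = model.forward_inference(t, h, context, audio[:, :t + 1])
            # greedy feedback: one-hot of the argmax over head h's logits at time t
            idx = out[0, -1, h].argmax(dim=-1)
            context[0, -1, h] = F.one_hot(idx, num_classes=model.ch_in).float()
        # append an empty context frame for the next time step
        context = th.cat([context, th.zeros_like(context[:, -1:])], dim=1)
    return out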