def forward_inference()

in models/context_model.py [0:0]


    def forward_inference(self, t: int, h: int, context: th.Tensor, audio: th.Tensor = None):
        """
        Incrementally compute the output of head `h` at time step `t`.

        Outputs are accumulated in `self.buffer` (1 x heads*ch_out x steps), one column per
        time step. The first call for a new time step (`self.historic_t < t`) appends a zero
        column and adds the contribution of the preceding context steps. Every call then adds
        head `h`'s contribution from the earlier heads of the current step and, if given,
        from the audio.

        :param t: current time step
        :param h: current head
        :param context: B x T x heads x ch_in Tensor; the last step feeds the head
            contribution, the receptive_field() - 1 steps before it feed the historic part
            (B is effectively assumed to be 1, since results are added into a batch-1 buffer)
        :param audio: B x T x audio_dim Tensor; only its last time step is used
        :return: 1 x T' x heads x ch_out Tensor with all buffered outputs, T' being the
            current buffer length (one column is appended per new time step)
        """
        B, T = context.shape[0], context.shape[1]
        # B x (heads * ch_in) x T, the channel-first layout conv1d expects
        context = context.view(B, T, -1).permute(0, 2, 1).contiguous()
        rows = slice(h * self.ch_out, (h + 1) * self.ch_out)  # output channels of head h
        new_step = self.historic_t < t

        if new_step:
            # open a fresh zero column for step t; new_zeros keeps the buffer's dtype/device
            self.buffer = self.buffer.to(context.device)
            self.buffer = th.cat([self.buffer, self.buffer.new_zeros(1, self.buffer.shape[1], 1)], dim=-1)

        # next head from previous head predictions: only heads < h of the current step are read
        y_masked = self.masked_linear.bias[rows].view(1, -1, 1).clone()  # clone: never modify the bias
        if h > 0:
            y_masked += F.conv1d(context[:, :h * self.ch_in, -1:],
                                 self.masked_linear.weight[rows, :h * self.ch_in, :])
        self.buffer[:, rows, -1:] += y_masked

        # next head from audio (last time step only)
        if audio is not None:
            audio_last = audio[:, -1:, :].permute(0, 2, 1).contiguous()
            y_audio = F.conv1d(audio_last,
                               self.unmasked_linear.weight[rows, :, :],
                               bias=self.unmasked_linear.bias[rows])
            self.buffer[:, rows, -1:] += y_audio

        # historic time steps: added once per time step, on the first call for step t
        if self.kernel_size > 0 and new_step:
            rf = self.receptive_field()
            history = context[:, :, -rf:-1]
            if history.shape[-1] < rf - 1:
                # left-pad with zeros when fewer than rf - 1 past steps exist yet
                history = F.pad(history, pad=[rf - 1 - history.shape[-1], 0])
            self.buffer[:, :, -1:] += self.historic(history)

        self.historic_t = t

        return self.buffer.permute(0, 2, 1).contiguous().view(1, -1, self.heads, self.ch_out)