def forward()

in fairmotion/models/transformer.py [0:0]


    def forward(self, src, tgt, max_len=None, teacher_forcing_ratio=None):
        """Encode ``src`` and decode a sequence of poses.

        Both inputs arrive batch-first and are transposed internally to the
        ``(len, batch_size, dim)`` layout the transformer layers expect; the
        result is transposed back before it is returned.

        Training mode (``self.training`` is true): a single teacher-forced
        decoder pass under a causal mask. The decoder input is ``tgt`` shifted
        right by one step, with the last source pose placed in the first slot.

        Evaluation mode: greedy decoding. The decoder input buffer is seeded
        with ``tgt[0]`` and each following slot is filled with the prediction
        made for the previous slot. The filled buffer itself is what gets
        returned.

        Args:
            src: Source sequence, batch-first ``(batch_size, len, dim)``.
            tgt: Target sequence, batch-first ``(batch_size, len, dim)``.
            max_len: Number of steps to produce in evaluation mode. Defaults
                to the length of ``tgt``. Not used in training mode.
            teacher_forcing_ratio: Accepted but not used by this method.

        Returns:
            Batch-first tensor: the projected decoder output in training
            mode, or the greedily filled decoder input buffer in evaluation
            mode.
        """
        # Transformer expects src and tgt in format (len, batch_size, dim)
        src = src.transpose(0, 1)
        tgt = tgt.transpose(0, 1)

        # src and tgt are now (T, B, E)
        if max_len is None:
            max_len = tgt.shape[0]

        # Embedding scale factor; loop-invariant, so compute it once.
        scale = np.sqrt(self.ninp)

        encoder_output = self.transformer_encoder(
            self.pos_encoder(self.encoder(src) * scale)
        )

        if self.training:
            # Causal mask so position i only attends to positions <= i
            tgt_mask = self._generate_square_subsequent_mask(tgt.shape[0]).to(
                device=tgt.device,
            )

            # Use last source pose as first input to decoder
            shifted_tgt = torch.cat((src[-1].unsqueeze(0), tgt[:-1]))
            decoder_output = self.transformer_decoder(
                self.pos_encoder(self.encoder(shifted_tgt) * scale),
                encoder_output,
                tgt_mask=tgt_mask,
            )
            output = self.project(decoder_output)
        else:
            # Greedy decoding into a preallocated buffer that shares the
            # dtype and device of src.
            decoder_input = src.new_zeros(
                (max_len, src.shape[1], src.shape[-1])
            )

            # Create mask for greedy decoding across the decoded output
            tgt_mask = self._generate_square_subsequent_mask(max_len).to(
                device=tgt.device
            )

            # NOTE(review): inference seeds the decoder with the ground-truth
            # tgt[0] (which is also returned as the first frame), whereas
            # training seeds it with src[-1]. Confirm this is intended.
            # NOTE(review): this branch does not disable autograd itself;
            # presumably callers wrap inference in torch.no_grad() - confirm.
            if max_len > 0:
                decoder_input[0] = tgt[0]

            # Slot `step` receives the prediction made at slot `step - 1`.
            # Only max_len - 1 decoder passes are needed: the prediction that
            # would follow the final slot is never part of the returned buffer.
            for step in range(1, max_len):
                pos_encoded_input = self.pos_encoder(
                    self.encoder(decoder_input) * scale
                )
                decoder_outputs = self.transformer_decoder(
                    pos_encoded_input, encoder_output, tgt_mask=tgt_mask,
                )
                predicted = self.project(decoder_outputs)
                # Indexed assignment copies the values into the buffer.
                decoder_input[step] = predicted[step - 1]
                del predicted
            output = decoder_input
        return output.transpose(0, 1)