in fairmotion/models/transformer.py [0:0]
def forward(self, src, tgt, max_len=None, teacher_forcing_ratio=None):
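    # Note: teacher_forcing_ratio is unused here; it is presumably kept so
    # the signature matches the other fairmotion model classes.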
    # Transformer expects src and tgt in format (len, batch_size, dim)
    src = src.transpose(0, 1)
    tgt = tgt.transpose(0, 1)
    # src and tgt are now (T, B, E)
    if max_len is None:
        max_len = tgt.shape[0]
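    # Project input poses to the model dimension, scale by sqrt(ninp) as in
    # "Attention Is All You Need", add positional encodings, then run the
    # transformer encoder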
    projected_src = self.encoder(src) * np.sqrt(self.ninp)
    pos_encoded_src = self.pos_encoder(projected_src)
    encoder_output = self.transformer_encoder(pos_encoded_src)
    if self.training:
        # Create causal mask so each target position attends only to
        # itself and earlier positions during training
        tgt_mask = self._generate_square_subsequent_mask(tgt.shape[0]).to(
            device=tgt.device,
        )
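        # Teacher forcing: the decoder consumes ground-truth target poses
        # shifted right by one step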
        # Use last source pose as first input to decoder
        tgt = torch.cat((src[-1].unsqueeze(0), tgt[:-1]))
        pos_encoder_tgt = self.pos_encoder(
            self.encoder(tgt) * np.sqrt(self.ninp)
        )
        output = self.transformer_decoder(
            pos_encoder_tgt, encoder_output, tgt_mask=tgt_mask,
        )
        output = self.project(output)
    else:
        # Greedy decoding: autoregressively generate max_len poses,
        # feeding each prediction back in as the next decoder input
        decoder_input = torch.zeros(
            max_len, src.shape[1], src.shape[-1],
        ).type_as(src.data)
        next_pose = tgt[0].clone()
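        # The seed pose here is the first target pose; note that the
        # training branch seeds the decoder with the last source pose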
        # Create causal mask for greedy decoding across the decoded output
        tgt_mask = self._generate_square_subsequent_mask(max_len).to(
            device=tgt.device
        )
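        # At step i, only positions 0..i of decoder_input are filled; the
        # causal mask ensures output[i] depends solely on those positions,
        # so the trailing zeros never affect the pose that is read out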
        for i in range(max_len):
            decoder_input[i] = next_pose
            pos_encoded_input = self.pos_encoder(
                self.encoder(decoder_input) * np.sqrt(self.ninp)
            )
            decoder_outputs = self.transformer_decoder(
                pos_encoded_input, encoder_output, tgt_mask=tgt_mask,
            )
            output = self.project(decoder_outputs)
            # Keep only the pose predicted at position i this iteration
            next_pose = output[i].clone()
            # Drop the full projection to reduce peak memory
            del output
        # The accumulated decoder inputs (seed pose followed by the greedy
        # predictions) form the output sequence
        output = decoder_input
    # Transpose back to (batch_size, len, dim)
    return output.transpose(0, 1)
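
# Minimal usage sketch (illustrative, not part of the file): the tensor
# shapes and elided constructor arguments below are assumptions, but
# forward() takes batch-first tensors and returns (batch_size, len, dim).
#
#   model = TransformerModel(...)  # constructor args elided
#   src = torch.rand(8, 120, 72)   # (batch, source_len, pose_dim)
#   tgt = torch.rand(8, 24, 72)    # (batch, target_len, pose_dim)
#   model.eval()                   # selects the greedy-decoding branch
#   with torch.no_grad():
#       pred = model(src, tgt)     # -> (8, 24, 72)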