fairmotion/models/rnn.py
def forward(self, src, tgt, max_len=None, teacher_forcing_ratio=0.5):
    """
    Inputs:
        src, tgt: Tensors of shape (batch_size, seq_len, input_dim)
        max_len: Maximum length of the sequence generated during
            inference. Set to None during training, where the target
            sequence length is used instead.
        teacher_forcing_ratio: Probability of feeding the ground truth
            target pose to the decoder instead of the pose predicted at
            the previous time step

    Returns:
        Tensor of shape (batch_size, seq_len, input_dim); seq_len is
        src_len + tgt_len - 1 during training and max_len during
        inference.
    """
    # Convert src, tgt to (seq_len, batch_size, input_dim) format
    src = src.transpose(0, 1)
    tgt = tgt.transpose(0, 1)
    lstm_input = self.dropout(src)
    # Generate as many poses as in tgt during training
    max_len = tgt.shape[0] if max_len is None else max_len
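    # Encoder pass: run the LSTM over the seed sequence; the final state
    # primes the decoder, and the per-step predictions are reused in the
    # training loss below.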
    encoder_outputs = torch.zeros(src.shape, device=src.device)
    _, state = self.run_lstm(
        lstm_input,
        encoder_outputs,
        state=None,
        teacher_forcing_ratio=teacher_forcing_ratio,
    )
    if self.training:
        decoder_outputs = torch.zeros(
            max_len - 1, src.shape[1], src.shape[2], device=src.device
        )
        tgt = self.dropout(tgt)
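        # Teacher forcing: tgt[:-1] serves as decoder input so that step t
        # predicts pose t + 1; with probability 1 - teacher_forcing_ratio,
        # run_lstm feeds back its own previous prediction instead.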
        decoder_outputs, _ = self.run_lstm(
            tgt[:-1],
            decoder_outputs,
            state=state,
            teacher_forcing_ratio=teacher_forcing_ratio,
        )
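        # Concatenate along time so the loss covers predictions for the
        # source frames as well as the target frames.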
        outputs = torch.cat((encoder_outputs, decoder_outputs))
    else:
        del encoder_outputs
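        # Autoregressive inference: seed the decoder with the last source
        # frame and feed each predicted pose back as the next input
        # (teacher_forcing_ratio=0 disables gold inputs).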
        outputs = torch.zeros(
            max_len, src.shape[1], src.shape[2], device=src.device
        )
        inputs = lstm_input[-1].unsqueeze(0)
        outputs, _ = self.run_lstm(
            inputs,
            outputs,
            state=state,
            max_len=max_len,
            teacher_forcing_ratio=0,
        )
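    # Convert back to (batch_size, seq_len, input_dim) format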
    return outputs.transpose(0, 1)
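
# Usage sketch (illustrative, not from the repo): assumes `model` is an
# instance of this module and that src/tgt follow the docstring shapes
# (batch_size, seq_len, input_dim); the dims below are arbitrary.
#
#     src = torch.rand(8, 120, 72)   # 8 clips, 120 seed poses each
#     tgt = torch.rand(8, 24, 72)    # 24 target poses per clip
#
#     model.train()
#     out = model(src, tgt, teacher_forcing_ratio=0.5)
#     # out: (8, 120 + 24 - 1, 72)
#
#     model.eval()
#     with torch.no_grad():
#         pred = model(src, tgt, max_len=24)
#     # pred: (8, 24, 72)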