def forward()

in common/quaternet.py [0:0]


    def forward(self, x, h=None, return_prenorm=False, return_all=False):
        """
        Feed a batch of sequences through the control encoder (if any), the recurrent
        layer and the output layer, then normalize the quaternion part of the result.

        Arguments:
         -- x: input tensor of shape (N, L, J*4 + O + C), with N the batch size, L the sequence length,
               J the number of joints, O the number of outputs and C the number of controls.
               The last axis must be laid out as joints (J*4), then outputs (O), then controls (C).
         -- h: hidden state handed to the recurrent layer. When None, the learned initial state
               self.h0 is expanded along the batch axis and used instead.
         -- return_prenorm: if True, additionally return the quaternion slice of the output layer
                            as it was before normalization.
         -- return_all: if True, the output layer is applied to all L frames; otherwise only to the
                        last frame, which skips the output-layer work for the other frames (useful
                        e.g. when the model is merely being conditioned on an initialization sequence).

        Returns:
         -- (out, h) or, when return_prenorm is True, (out, h, pre_normalized), where
            out holds the normalized quaternions (first J*4 features) followed by the remaining
            output-layer features when self.num_outputs > 0, h is the hidden state returned by
            the recurrent layer, and pre_normalized is the raw J*4 quaternion slice.
        """
        assert len(x.shape) == 3
        quat_dim = self.num_joints*4
        state_dim = quat_dim + self.num_outputs
        assert x.shape[-1] == state_dim + self.num_controls

        # Controls pass through a two-layer ReLU encoder before they reach the recurrent layer;
        # the joint/output features are forwarded untouched.
        rnn_input = x
        if self.num_controls > 0:
            encoded = self.relu(self.fc1(x[:, :, state_dim:]))
            encoded = self.relu(self.fc2(encoded))
            rnn_input = torch.cat((x[:, :, :state_dim], encoded), dim=2)

        if h is None:
            # Broadcast the learned initial state over the batch dimension.
            h = self.h0.expand(-1, x.shape[0], -1).contiguous()
        features, h = self.rnn(rnn_input, h)

        # Unless every frame is requested, keep only the last one (the time axis is preserved).
        reference = x
        if not return_all:
            features = features[:, -1:]
            reference = x[:, -1:]
        out = self.fc(features)

        pre_normalized = out[:, :, :quat_dim].contiguous()
        quats = pre_normalized.view(-1, 4)
        if self.model_velocities:
            # In velocity mode the raw output is combined (qmul) with the input quaternions
            # of the matching frame(s) before normalization.
            previous = reference[:, :, :quat_dim].contiguous().view(-1, 4)
            quats = qmul(quats, previous)
        normalized = F.normalize(quats, dim=1).view(pre_normalized.shape)

        if self.num_outputs > 0:
            out = torch.cat((normalized, out[:, :, quat_dim:]), dim=2)
        else:
            out = normalized

        if return_prenorm:
            return out, h, pre_normalized
        return out, h