def decode()

in main/model.py [0:0]


    def decode(self, joint_euler):
        """Decode joint Euler angles into posed joints and a skinned, refined mesh.

        Steps:
          1. Refine the rest skeleton with an identity-dependent corrective
             added to each joint's local translation.
          2. Pose every sample by forward kinematics, with the root's rigid
             transform fixed to the identity (testing stage).
          3. Add pose- and identity-dependent corrective offsets to the
             template vertices ``mesh['v']``.
          4. Deform the corrected vertices with linear blend skinning.

        Args:
            joint_euler: per-joint Euler angles; dim 0 is the batch. Presumably
                (batch, joint_num, 3) -- TODO confirm against ``euler2mat``.

        Returns:
            dict with
              'joint_out': (batch, joint_num, 3) posed joint positions.
              'mesh_out_refined': skinned vertex positions, presumably
                  (batch, num_verts, 3).

        Side effects:
            Writes 'local_pose_refined' and 'global_pose_refined' into the dict
            returned by ``self.get_mesh_data``.
        """
        batch_size = joint_euler.shape[0]
        mesh = self.get_mesh_data(batch_size)
        joint_rot_mat = euler2mat(joint_euler, to_4x4=True)

        # --- skeleton corrective: shift each joint's local translation ---
        skeleton_corrective = self.skeleton_refine_net(self.id_code)
        local_pose_refined = mesh['local_pose'].clone()
        local_pose_refined[:, :3, 3] += skeleton_corrective
        mesh['local_pose_refined'] = local_pose_refined

        # Refined rest-pose skeleton. It is not consumed below; it is stored on
        # ``mesh`` (NOTE(review): confirm some caller still reads it).
        global_pose_refined = [None] * self.joint_num
        global_pose_refined[self.root_joint_idx] = mesh['global_pose'][self.root_joint_idx].clone()
        forward_kinematics(self.skeleton, self.root_joint_idx, local_pose_refined, global_pose_refined)
        mesh['global_pose_refined'] = torch.stack(global_pose_refined)

        # --- forward kinematics, one sample at a time ---
        # Root rigid transform: use identity matrix in testing stage. It is
        # created on the skeleton's own device/dtype rather than with a
        # hard-coded ``.float().cuda()``, which fails on CPU and on a
        # non-default GPU.
        eye_kwargs = {'dtype': local_pose_refined.dtype, 'device': local_pose_refined.device}
        global_pose_inv = mesh['global_pose_inv'][:self.joint_num]
        joint_out = []
        joint_trans_mat = []
        for i in range(batch_size):
            global_pose = [None] * self.joint_num
            global_pose[self.root_joint_idx] = torch.eye(4, **eye_kwargs)
            # forward_kinematics fills ``global_pose`` in place.
            forward_kinematics(self.skeleton, self.root_joint_idx, torch.bmm(local_pose_refined, joint_rot_mat[i]), global_pose)
            global_pose = torch.stack(global_pose)  # (joint_num, 4, 4)
            joint_out.append(global_pose[:, :3, 3])
            # Posed transform composed with the inverse template global pose.
            joint_trans_mat.append(torch.matmul(global_pose, global_pose_inv))
        joint_out = torch.stack(joint_out)  # (batch, joint_num, 3)
        joint_trans_mat = torch.stack(joint_trans_mat, dim=1)  # (joint_num, batch, 4, 4)

        # --- corrective vertex offsets ---
        # joint_euler is detached so gradients from the skin correctives do
        # not flow back into the joint angles.
        pose_corrective, id_corrective = self.skin_refine_net(joint_euler.detach(), self.id_code[None, :].repeat(batch_size, 1))
        mesh_refined_xyz = mesh['v'] + pose_corrective + id_corrective

        # --- linear blend skinning ---
        # Homogeneous coordinates: append a column of ones.
        mesh_refined_xyz1 = torch.cat([mesh_refined_xyz, torch.ones_like(mesh_refined_xyz[:, :, :1])], 2)
        mesh_refined_xyz1_t = mesh_refined_xyz1.permute(0, 2, 1)  # loop-invariant, hoisted
        skinning_weight = mesh['skinning_weight']
        # Generator instead of a list: the per-joint terms are summed in the
        # same order without first materializing a list of all of them.
        mesh_out_refined = sum(
            skinning_weight[:, :, j, None]
            * torch.bmm(joint_trans_mat[j], mesh_refined_xyz1_t).permute(0, 2, 1)[:, :, :3]
            for j in range(self.joint_num)
        )

        return {'joint_out': joint_out, 'mesh_out_refined': mesh_out_refined}