def _compute()

in theseus/utils/utils.py [0:0]


    def _compute(group_idx):
        """Numerically estimate the Jacobian of ``function`` w.r.t. one argument.

        Uses a central finite difference in the tangent space of
        ``group_args[group_idx]``. Each tangent direction ``d`` is perturbed by
        ``+delta_mag`` and ``-delta_mag`` via ``retract``. ``function`` is
        evaluated on both perturbed argument lists, and the ``local`` difference
        between the two outputs is divided by ``2 * delta_mag``.

        Reads ``group_args``, ``function``, ``function_dim``, ``batch_size`` and
        ``delta_mag`` from the enclosing scope.

        Args:
            group_idx: index into ``group_args`` of the argument to
                differentiate with respect to.

        Returns:
            A tensor of shape ``(batch_size, function_dim_, dof)``. Here ``dof``
            is ``group_args[group_idx].dof()``, and ``function_dim_`` is
            ``function_dim`` when truthy, otherwise ``dof``. The tensor is
            allocated with the dtype and on the device of
            ``group_args[group_idx]``.
        """
        group = group_args[group_idx]
        dof = group.dof()
        # A falsy function_dim (None or 0) means "output dim equals input dof".
        function_dim_ = function_dim or dof
        # Allocate on the argument's device so the returned Jacobian lives
        # alongside the inputs instead of defaulting to the CPU.
        jac = torch.zeros(
            batch_size, function_dim_, dof, dtype=group.dtype, device=group.device
        )
        two_delta = 2 * delta_mag
        for d in range(dof):
            # Tangent direction d scaled by delta_mag, shape (1, dof).
            # NOTE(review): presumably broadcast over the batch by retract; confirm.
            delta = torch.zeros(1, dof, dtype=group.dtype, device=group.device)
            delta[:, d] = delta_mag

            # Shallow copies of the argument list with only group_idx replaced.
            group_plus_args = list(group_args)
            group_plus_args[group_idx] = group.retract(delta)
            group_minus_args = list(group_args)
            group_minus_args[group_idx] = group.retract(-delta)

            # Central difference: local(f(x - delta), f(x + delta)) / (2 * delta).
            diff = function(group_minus_args).local(function(group_plus_args))
            jac[:, :, d] = diff / two_delta
        return jac