in theseus/utils/utils.py [0:0]
def _compute(group_idx):
    """Estimate, by central differences, the Jacobian w.r.t. one argument.

    For every degree of freedom ``d`` of ``group_args[group_idx]`` the
    argument is perturbed by ``+delta_mag`` and ``-delta_mag`` along ``d``
    (through its ``retract`` method), ``function`` is evaluated on both
    perturbed argument lists, and the difference between the two results
    (through the ``local`` method of the "minus" result) divided by
    ``2 * delta_mag`` is stored as column ``d`` of the Jacobian.

    The names ``group_args``, ``function``, ``function_dim``, ``batch_size``
    and ``delta_mag`` are not defined in this block; they must be supplied
    by the surrounding scope.

    Args:
        group_idx: Index into ``group_args`` of the argument to perturb.

    Returns:
        A tensor of shape ``(batch_size, function_dim_, dof)`` whose dtype
        is that of ``group_args[group_idx]``. ``function_dim_`` is
        ``function_dim`` when that value is truthy, otherwise ``dof``.
    """
    target = group_args[group_idx]
    dof = target.dof()
    # A falsy ``function_dim`` (e.g. None) falls back to the argument's dof.
    function_dim_ = function_dim or dof
    # NOTE(review): no device is passed here, so the result lives on torch's
    # default device even when the arguments do not. Kept as-is so the
    # device of the returned tensor does not change for existing callers.
    jac = torch.zeros(batch_size, function_dim_, dof, dtype=target.dtype)
    for d in range(dof):
        # Build the perturbation directly on the device/dtype of the argument
        # it is applied to; a fresh tensor is created per step so that
        # nothing returned by ``retract`` can alias a later perturbation.
        delta = torch.zeros(1, dof, device=target.device, dtype=target.dtype)
        delta[:, d] = delta_mag

        # Shallow copies: only the perturbed slot differs from ``group_args``.
        group_plus_args = list(group_args)
        group_plus_args[group_idx] = target.retract(delta)
        group_minus_args = list(group_args)
        group_minus_args[group_idx] = target.retract(-delta)

        # Central difference; assumes ``local`` returns something assignable
        # to a (batch_size, function_dim_) slice -- TODO confirm with callers.
        diff = function(group_minus_args).local(function(group_plus_args))
        jac[:, :, d] = diff / (2 * delta_mag)
    return jac