in shap_e/models/nn/checkpoint.py [0:0]
def forward(ctx, run_function, length_1, length_2, *args):
    """Re-run ``run_function`` and return the gradients of its outputs.

    ``args`` is a flat sequence made of three consecutive groups:

    * ``args[:length_1]`` -- input tensors. Each one is detached and turned
      into a fresh leaf with ``requires_grad=True`` before use.
    * ``args[length_1:length_1 + length_2]`` -- extra tensors (presumably
      module parameters; confirm at the call site). Gradients are also taken
      with respect to these, and they are used as-is, without detaching.
    * ``args[length_1 + length_2:]`` -- gradients w.r.t. the outputs of
      ``run_function``. They are passed as ``grad_outputs`` to
      ``torch.autograd.grad``.

    ``run_function``, ``length_1`` and ``length_2`` are stored on ``ctx``.
    All three tensor groups are saved with ``ctx.save_for_backward``,
    presumably so a matching ``backward`` can rebuild them (not visible here).

    Returns:
        The tuple produced by ``torch.autograd.grad``. It holds one gradient
        per input tensor followed by one per extra tensor. Because
        ``allow_unused=True``, an entry is ``None`` when the corresponding
        tensor does not influence the outputs.
    """
    split = length_1 + length_2
    ctx.run_function = run_function
    ctx.length_1, ctx.length_2 = length_1, length_2

    leaves = [arg.detach().requires_grad_(True) for arg in args[:length_1]]
    params = list(args[length_1:split])
    grad_outputs = list(args[split:])
    ctx.save_for_backward(*leaves, *params, *grad_outputs)

    with torch.enable_grad():
        # Work on views rather than on the detached leaves themselves. This
        # fixes a bug where the first op in run_function modifies the Tensor
        # storage in place, which is not allowed for detach()'d Tensors.
        views = [leaf.view_as(leaf) for leaf in leaves]
        outputs = ctx.run_function(*views)

    # Vector-Jacobian product of the recomputed outputs with grad_outputs,
    # taken w.r.t. the detached input leaves followed by the extra tensors.
    return torch.autograd.grad(
        outputs,
        leaves + params,
        grad_outputs,
        allow_unused=True,
    )