in shap_e/models/nn/checkpoint.py [0:0]
def backward(ctx, *all_output_grads):
    """Second-order backward pass for a recomputed vector-Jacobian product.

    The saved tensors are laid out as
    ``[inputs (ctx.length_1), params (ctx.length_2), output_grads (rest)]``.
    ``ctx.run_function`` is re-executed under grad mode to rebuild the
    first-order gradients of its outputs w.r.t. ``inputs + params`` (weighted
    by the saved ``output_grads``). ``create_graph=True`` makes those
    gradients differentiable. They are then differentiated against
    ``all_output_grads`` w.r.t. ``inputs + params + output_grads``.

    Args:
        ctx: autograd context; ``saved_tensors``, ``length_1``, ``length_2``
            and ``run_function`` are read from it.
        *all_output_grads: incoming gradients, one per first-order gradient
            (presumably one per output of the matching ``forward`` -- confirm).
            ``None`` entries are treated as a zero contribution.

    Returns:
        ``(None, None, None)`` followed by one gradient (or ``None`` when it
        does not depend on the incoming gradients) per saved tensor. The three
        leading ``None`` values presumably correspond to the non-tensor
        arguments of ``forward`` -- confirm against ``forward``.

    Raises:
        RuntimeError: if the number of incoming gradients differs from the
            number of first-order gradients.
    """
    args = ctx.saved_tensors
    n_inputs = ctx.length_1
    n_params = ctx.length_2
    input_tensors = [x.detach().requires_grad_(True) for x in args[:n_inputs]]
    input_params = list(args[n_inputs : n_inputs + n_params])
    output_grads = [
        x.detach().requires_grad_(True) for x in args[n_inputs + n_params :]
    ]
    grad_targets = input_tensors + input_params + output_grads
    with torch.enable_grad():
        # Fixes a bug where the first op in run_function modifies the
        # Tensor storage in place, which is not allowed for detach()'d
        # Tensors.
        shallow_copies = [x.view_as(x) for x in input_tensors]
        output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            input_tensors + input_params,
            output_grads,
            allow_unused=True,
            create_graph=True,
            retain_graph=True,
        )
        if len(input_grads) != len(all_output_grads):
            raise RuntimeError(
                f"expected {len(input_grads)} incoming gradients, "
                f"got {len(all_output_grads)}"
            )
        # Keep only terms that can contribute to the second-order gradient.
        # A None first-order gradient (input unused by run_function) cannot
        # be passed as an output to torch.autograd.grad. A None incoming
        # gradient means "zero" here, but torch.autograd.grad would treat it
        # as ones for scalar outputs.
        pairs = [
            (grad, upstream)
            for grad, upstream in zip(input_grads, all_output_grads)
            if grad is not None and upstream is not None and grad.requires_grad
        ]
        if pairs:
            outputs, grad_outputs = zip(*pairs)
            input_grads_grads = torch.autograd.grad(
                outputs,
                grad_targets,
                grad_outputs,
                allow_unused=True,
            )
        else:
            # Nothing flows back: every target receives no gradient.
            input_grads_grads = (None,) * len(grad_targets)
        # Release the first-order graph before returning.
        del input_grads, pairs
    return (None, None, None) + tuple(input_grads_grads)